return n1 == n2 and a1[1:] == a2[1:]
-def _destination_and_source(af, where, port, source, source_port):
+def _destination_and_source(af, where, port, source, source_port,
+ default_to_inet=True):
# Apply defaults and compute destination and source tuples
# suitable for use in connect(), sendto(), or bind().
if af is None:
try:
af = dns.inet.af_for_address(where)
except Exception:
- af = dns.inet.AF_INET
+ if default_to_inet:
+ af = dns.inet.AF_INET
if af == dns.inet.AF_INET:
destination = (where, port)
if source is not None or source_port != 0:
if source is None:
source = '::'
source = (source, source_port, 0, 0)
+ else:
+ source = None
+ destination = None
return (af, destination, source)
def send_https(session, what, lifetime=None):
what = what.prepare()
return session.send(what, timeout=lifetime)
-def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True,
- bootstrap_address=None, verify=True, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False):
+def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ session=None, path='/dns-query', post=True,
+ bootstrap_address=None, verify=True):
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*q*, a ``dns.message.Message``, the query to send.
address is given, the URL will be constructed using the following schema:
https://<IP-address>:<port>/<path>.
- *session*, a ``requests.session.Session``, the session to use to send the
- queries. This argument is required to allow for connection reuse.
-
*timeout*, a ``float`` or ``None``, the number of seconds to
wait before the query times out. If ``None``, the default, wait forever.
- *port*, a ``int``, the port to send the query to. Default is 443.
-
- *path*, a ``str``. If *where* is an IP address, then *path* will be used to
- construct the URL to send the DNS query to.
-
- *post*, a ``bool``. If ``True``, the default, POST method will be used.
+ *port*, a ``int``, the port to send the query to. The default is 443.
- *bootstrap_address*, a ``str``, the IP address to use to bypass the system's
- DNS resolver.
+ *af*, an ``int``, the address family to use. The default is ``None``,
+ which causes the address family to use to be inferred from the form of
+ *where*, or uses the system default. Setting this to AF_INET or
+ AF_INET6 currently has no effect.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *session*, a ``requests.session.Session``. If provided, the session to use
+ to send the queries.
+
+ *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+ construct the URL to send the DNS query to.
+
+ *post*, a ``bool``. If ``True``, the default, POST method will be used.
+
+ *bootstrap_address*, a ``str``, the IP address to use to bypass the
+ system's DNS resolver.
+
+ *verify*, a ``str`, containing a path to a certificate file or directory.
+
Returns a ``dns.message.Message``.
"""
wire = q.to_wire()
- af = None
(af, destination, source) = _destination_and_source(af, where, port,
- source, source_port)
+ source, source_port,
+ False)
+ transport_adapter = None
headers = {
"accept": "application/dns-message"
}
split_url = urllib.parse.urlsplit(where)
headers['Host'] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
- session.mount(url, HostHeaderSSLAdapter())
+ transport_adapter = HostHeaderSSLAdapter()
else:
url = where
if source is not None:
# set source port and source address
- session.mount(url, SourceAddressAdapter(source))
-
- # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
- if post:
- headers.update({
- "content-type": "application/dns-message",
- "content-length": str(len(wire))
- })
- response = session.post(url, headers=headers, data=wire, stream=True,
- timeout=timeout, verify=verify)
+ transport_adapter = SourceAddressAdapter(source)
+
+ if session:
+ close_session = False
else:
- wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
- url += "?dns={}".format(wire)
- response = session.get(url, headers=headers, stream=True,
- timeout=timeout, verify=verify)
+ session = requests.sessions.Session()
+ close_session = True
+
+ try:
+ if transport_adapter:
+ session.mount(url, transport_adapter)
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+ # GET and POST examples
+ if post:
+ headers.update({
+ "content-type": "application/dns-message",
+ "content-length": str(len(wire))
+ })
+ response = session.post(url, headers=headers, data=wire,
+ stream=True, timeout=timeout,
+ verify=verify)
+ else:
+ wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+ url += "?dns={}".format(wire)
+ response = session.get(url, headers=headers, stream=True,
+ timeout=timeout, verify=verify)
+ finally:
+ if close_session:
+ session.close()
- # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes
+ # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+ # status codes
if response.status_code < 200 or response.status_code > 299:
- raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
- where, response.status_code, response.content))
+ raise ValueError('{} responded with status code {}'
+ '\nResponse body: {}'.format(where,
+ response.status_code,
+ response.content))
r = dns.message.from_wire(response.content,
keyring=q.keyring,
request_mac=q.request_mac,
def test_get_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = dns.query.https(q, nameserver_url, self.session, post=False)
+ r = dns.query.https(q, nameserver_url, session=self.session, post=False)
self.assertTrue(q.is_response(r))
def test_post_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = dns.query.https(q, nameserver_url, self.session, post=True)
+ r = dns.query.https(q, nameserver_url, session=self.session, post=True)
self.assertTrue(q.is_response(r))
def test_build_url_from_ip(self):
q = dns.message.make_query('example.com.', dns.rdatatype.A)
# For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
# So we're just going to do GET requests here
- r = dns.query.https(q, nameserver_ip, self.session, post=False)
+ r = dns.query.https(q, nameserver_ip, session=self.session, post=False)
self.assertTrue(q.is_response(r))
def test_bootstrap_address(self):
q = dns.message.make_query('example.com.', dns.rdatatype.A)
# make sure CleanBrowsing's IP address will fail TLS certificate check
with self.assertRaises(SSLError):
- dns.query.https(q, invalid_tls_url, self.session)
+ dns.query.https(q, invalid_tls_url, session=self.session)
# use host header
- r = dns.query.https(q, valid_tls_url, self.session, bootstrap_address=ip)
+ r = dns.query.https(q, valid_tls_url, session=self.session, bootstrap_address=ip)
self.assertTrue(q.is_response(r))
def test_send_https(self):
dns_resp = dns.message.from_wire(response.content)
self.assertTrue(q.is_response(dns_resp))
+ def test_new_session(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ r = dns.query.https(q, nameserver_url)
+ self.assertTrue(q.is_response(r))
+
if __name__ == '__main__':
unittest.main()