import dns.rdatatype
import requests
+import requests.packages.urllib3.util.connection as urllib3_cn
+from requests_toolbelt.adapters.source import SourceAddressAdapter
try:
import ssl
return (af, destination, source)
-def https(q, where, timeout=None, af=None, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False,
- post=True, path='/dns-query', verify=True):
+def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
+ verify=True, af=None, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False):
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*q*, a ``dns.message.Message``, the query to send.
- *where*, a ``str``, the nameserver IP address or the full URL.
+ *where*, a ``str``, the nameserver IP address or the full URL. If an IP
+ address is given, the URL will be constructed using the following schema:
+ https:<IP-address>:<port>/<path>.
*timeout*, a ``float`` or ``None``, the number of seconds to
wait before the query times out. If ``None``, the default, wait forever.
+ *port*, a ``int``, the port to send the query to. Default is 443.
+
+ *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+ construct the URL to send the DNS query to.
+
+ *post*, a ``bool``. If ``True``, the default, POST method will be used.
+
*af*, an ``int``, the address family to use. The default is ``None``,
which causes the address family to use to be inferred from the form of
*where*. If the inference attempt fails, AF_INET is used. This
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
- *post*, a ``bool``. If ``True``, the default, POST method should be used.
-
- *path*, a ``str``. If *where* is an IP address, then *path* will be used to
- construct the URL to send the DNS query to
-
Returns a ``dns.message.Message``.
"""
wire = q.to_wire()
- port = 443
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
+ if source is None:
+ source = ('', 0)
+ if headers is None:
+ headers = {}
+ with requests.Session() as session:
+ # set source port and source address
+ # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
+ session.mount('http://', SourceAddressAdapter(source))
+ session.mount('https://', SourceAddressAdapter(source))
+
+ # effectively set address family
+ # see https://stackoverflow.com/a/46972341/9638991
+ urllib3_cn.allowed_gai_family = lambda: af
- # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
- from requests_toolbelt.adapters.source import SourceAddressAdapter
- session = requests.Session()
- session.mount('http://', SourceAddressAdapter(source))
- session.mount('https://', SourceAddressAdapter(source))
-
- # see https://stackoverflow.com/a/46972341/9638991
- import requests.packages.urllib3.util.connection as urllib3_cn
-
- def allowed_gai_family():
- return af
-
- urllib3_cn.allowed_gai_family = allowed_gai_family
-
- try:
- _ = ipaddress.ip_address(where)
- url = 'https://{}{}'.format(where, path)
- except ValueError:
- url = where
-
- # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
- if post:
- headers = {
- "accept": "application/dns-message",
- "content-type": "application/dns-message",
- "content-length": str(len(wire))
- }
- response = session.post(url, headers=headers, data=wire, stream=True, timeout=timeout, verify=verify)
- else:
- wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
- headers = {
- "accept": "application/dns-message"
- }
- url += "?dns={}".format(wire)
- response = session.get(url, headers=headers, stream=True, timeout=timeout, verify=verify)
-
- # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for more info about DoH status codes
- if 200 > response.status_code > 299:
- raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
- where, response.status_code, response.content))
- r = dns.message.from_wire(response.content,
- keyring=q.keyring,
- request_mac=q.request_mac,
- one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing)
+ try:
+ _ = ipaddress.ip_address(where)
+ url = 'https://{}:{}{}'.format(where, port, path)
+ except ValueError:
+ url = where
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
+ if post:
+ headers = {
+ "accept": "application/dns-message",
+ "content-type": "application/dns-message",
+ "content-length": str(len(wire))
+ }
+ response = session.post(url, headers=headers, data=wire, stream=True,
+ timeout=timeout, verify=verify)
+ else:
+ wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+ headers = {
+ "accept": "application/dns-message"
+ }
+ url += "?dns={}".format(wire)
+ response = session.get(url, headers=headers, stream=True,
+ timeout=timeout, verify=verify)
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes
+ if response.status_code < 200 or response.status_code > 299:
+ raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
+ where, response.status_code, response.content))
+ r = dns.message.from_wire(response.content,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing)
r.time = response.elapsed
- r.status_code = response.status_code
if not q.is_response(r):
raise BadResponse
return r
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
import unittest
import random
self.assertTrue(q.is_response(r))
def test_build_url_from_ip(self):
- nameserver_ip = '8.8.8.8' #random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
+ nameserver_ip = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
q = dns.message.make_query('example.com.', dns.rdatatype.A)
# For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
- # So we're just going to do the GET request
+ # So we're just going to do GET requests here
r = dns.query.https(q, nameserver_ip, post=False)
self.assertTrue(q.is_response(r))
- def test_custom_path(self):
- cleanbrowsing_ip = '185.228.168.168'
- cleanbrowsing_path = '/doh/security-filter/'
- q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = dns.query.https(q, cleanbrowsing_ip, path=cleanbrowsing_path, verify=False)
- self.assertTrue(q.is_response(r))
-
- def test_use_full_url(self):
- pass
-
if __name__ == '__main__':
unittest.main()