From: Charles-Henri Bruyand Date: Wed, 22 Nov 2023 10:03:16 +0000 (+0100) Subject: dnsdist: add basic DoHTTP/3 test X-Git-Tag: dnsdist-1.9.0-alpha4~15^2~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4f0b10a9822da0db99b63f83a1baea869aa82fad;p=thirdparty%2Fpdns.git dnsdist: add basic DoHTTP/3 test --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index f2e5e31a99..7b1cca3530 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -29,6 +29,7 @@ import pycurl from io import BytesIO from doqclient import quic_query +from doh3client import doh3_query from eqdnsmessage import AssertEqualDNSMessageMixin from proxyprotocol import ProxyProtocol @@ -1126,3 +1127,26 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): receivedQuery = cls._fromResponderQueue.get(True, timeout) return (receivedQuery, message) + + @classmethod + def sendDOH3Query(cls, port, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, fromQueue=None, toQueue=None, connection=None, serverName=None): + + if response: + if toQueue: + toQueue.put(response, True, timeout) + else: + cls._toResponderQueue.put(response, True, timeout) + + message = doh3_query(query, baseurl, timeout, port, verify=caFile, server_hostname=serverName) + + receivedQuery = None + + if useQueue: + if fromQueue: + if not fromQueue.empty(): + receivedQuery = fromQueue.get(True, timeout) + else: + if not cls._fromResponderQueue.empty(): + receivedQuery = cls._fromResponderQueue.get(True, timeout) + + return (receivedQuery, message) diff --git a/regression-tests.dnsdist/doh3client.py b/regression-tests.dnsdist/doh3client.py new file mode 100644 index 0000000000..eeebb4c3b6 --- /dev/null +++ b/regression-tests.dnsdist/doh3client.py @@ -0,0 +1,280 @@ +import base64 +import asyncio +import pickle +import ssl +import struct +import dns +import time +import async_timeout + +from collections import deque +from typing import BinaryIO, Callable, Deque, Dict, List, Optional, Union, cast +from urllib.parse import urlparse + +import aioquic +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.h0.connection import H0_ALPN, H0Connection +from aioquic.h3.connection import H3_ALPN, ErrorCode, H3Connection +from aioquic.h3.events import ( + DataReceived, + H3Event, + HeadersReceived, + PushPromiseReceived, +) +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import QuicEvent +#from aioquic.quic.logger import QuicFileLogger +from aioquic.tls import CipherSuite, SessionTicket +# +#class DnsClientProtocol(QuicConnectionProtocol): +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self._ack_waiter: Any = None +# +# def pack(self, data): +# # serialize query +# data = bytes(data) +# data = struct.pack("!H", len(data)) + data +# return data +# +# async def query(self, query: dns.message) -> None: +# data = self.pack(query.to_wire()) +# # send query and wait for answer +# stream_id = self._quic.get_next_available_stream_id() +# self._quic.send_stream_data(stream_id, data, end_stream=True) +# waiter = self._loop.create_future() +# self._ack_waiter = waiter +# self.transmit() +# +# return await asyncio.shield(waiter) +# +# def quic_event_received(self, event: QuicEvent) -> None: +# if self._ack_waiter is not None: +# if isinstance(event, StreamDataReceived): +# length = struct.unpack("!H", bytes(event.data[:2]))[0] +# answer = dns.message.from_wire(event.data[2 : 2 + length], ignore_trailing=True) +# +# waiter = self._ack_waiter +# self._ack_waiter = None +# waiter.set_result(answer) +# if isinstance(event, StreamReset): +# waiter = self._ack_waiter +# self._ack_waiter = None +# waiter.set_result(event) +# +#class BogusDnsClientProtocol(DnsClientProtocol): +# def pack(self, data): +# # serialize query +# data = bytes(data) +# data = struct.pack("!H", len(data) * 2) + data +# return data +HttpConnection = Union[H0Connection, H3Connection] + +class URL: + def __init__(self, url: str) -> None: + parsed = urlparse(url) + + self.authority = parsed.netloc + self.full_path = parsed.path or "/" + if parsed.query: + self.full_path += "?" + parsed.query + self.scheme = parsed.scheme + + +class HttpRequest: + def __init__( + self, + method: str, + url: URL, + content: bytes = b"", + headers: Optional[Dict] = None, + ) -> None: + if headers is None: + headers = {} + + self.content = content + self.headers = headers + self.method = method + self.url = url + +class HttpClient(QuicConnectionProtocol): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.pushes: Dict[int, Deque[H3Event]] = {} + self._http: Optional[HttpConnection] = None + self._request_events: Dict[int, Deque[H3Event]] = {} + self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {} + + if self._quic.configuration.alpn_protocols[0].startswith("hq-"): + self._http = H0Connection(self._quic) + else: + self._http = H3Connection(self._quic) + + async def get(self, url: str, headers: Optional[Dict] = None) -> Deque[H3Event]: + """ + Perform a GET request. + """ + return await self._request( + HttpRequest(method="GET", url=URL(url), headers=headers) + ) + + async def post( + self, url: str, data: bytes, headers: Optional[Dict] = None + ) -> Deque[H3Event]: + """ + Perform a POST request. + """ + return await self._request( + HttpRequest(method="POST", url=URL(url), content=data, headers=headers) + ) + + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, (HeadersReceived, DataReceived)): + stream_id = event.stream_id + if stream_id in self._request_events: + # http + self._request_events[event.stream_id].append(event) + if event.stream_ended: + request_waiter = self._request_waiter.pop(stream_id) + request_waiter.set_result(self._request_events.pop(stream_id)) + + elif stream_id in self._websockets: + # websocket + websocket = self._websockets[stream_id] + websocket.http_event_received(event) + + elif event.push_id in self.pushes: + # push + self.pushes[event.push_id].append(event) + + elif isinstance(event, PushPromiseReceived): + self.pushes[event.push_id] = deque() + self.pushes[event.push_id].append(event) + + def quic_event_received(self, event: QuicEvent) -> None: + #  pass event to the HTTP layer + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + async def _request(self, request: HttpRequest) -> Deque[H3Event]: + stream_id = self._quic.get_next_available_stream_id() + self._http.send_headers( + stream_id=stream_id, + headers=[ + (b":method", request.method.encode()), + (b":scheme", request.url.scheme.encode()), + (b":authority", request.url.authority.encode()), + (b":path", request.url.full_path.encode()), + ] + + [(k.encode(), v.encode()) for (k, v) in request.headers.items()], + end_stream=not request.content, + ) + if request.content: + self._http.send_data( + stream_id=stream_id, data=request.content, end_stream=True + ) + + waiter = self._loop.create_future() + self._request_events[stream_id] = deque() + self._request_waiter[stream_id] = waiter + self.transmit() + + return await asyncio.shield(waiter) + + +async def perform_http_request( + client: HttpClient, + url: str, + data: Optional[str], + include: bool, + output_dir: Optional[str], +) -> None: + # perform request + start = time.time() + if data is not None: + data_bytes = data.encode() + http_events = await client.post( + url, + data=data_bytes, + headers={ + "content-length": str(len(data_bytes)), + "content-type": "application/x-www-form-urlencoded", + }, + ) + method = "POST" + else: + http_events = await client.get(url) + method = "GET" + elapsed = time.time() - start + + result = bytes() + for http_event in http_events: + if isinstance(http_event, DataReceived): + result += http_event.data + return result + + +async def async_h3_query( + configuration: QuicConfiguration, + baseurl: str, + port: int, + query: dns.message, + timeout: float, + create_protocol=HttpClient +) -> None: + + url = "{}?dns={}".format(baseurl, base64.urlsafe_b64encode(query.to_wire()).decode('UTF8').rstrip('=')) + print("Querying for {}".format(url)) + async with connect( + "127.0.0.1", + port, + configuration=configuration, + create_protocol=create_protocol, + ) as client: + client = cast(HttpClient, client) + + print("Sending DNS query") + try: + async with async_timeout.timeout(timeout): + + answer = await perform_http_request( + client=client, + url=url, + data=None, + include=False, + output_dir=None, + ) + + return answer + except asyncio.TimeoutError as e: + return e + +class StreamResetError(Exception): + def __init__(self, error, message="Stream reset by peer"): + self.error = error + super().__init__(message) + +def doh3_query(query, baseurl, timeout=2, port=853, verify=None, server_hostname=None): + configuration = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True) + if verify: + configuration.load_verify_locations(verify) + result = asyncio.run( + async_h3_query( + configuration=configuration, + baseurl=baseurl, + port=port, + query=query, + timeout=timeout, + create_protocol=HttpClient + ) + ) + # if (isinstance(result, StreamReset)): + # raise StreamResetError(result.error_code) + if (isinstance(result, asyncio.TimeoutError)): + raise TimeoutError() + return result + diff --git a/regression-tests.dnsdist/test_DOH3.py b/regression-tests.dnsdist/test_DOH3.py new file mode 100644 index 0000000000..74e4bb15a0 --- /dev/null +++ b/regression-tests.dnsdist/test_DOH3.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +import dns +import clientsubnetoption + +from dnsdisttests import DNSDistTest +from dnsdisttests import pickAvailablePort + +import doh3client + +class TestDOH3(DNSDistTest): + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _doqServerPort = pickAvailablePort() + _dohBaseURL = ("https://%s:%d/" % (_serverName, _doqServerPort)) + _config_template = """ + newServer{address="127.0.0.1:%d"} + + addAction("drop.doq.tests.powerdns.com.", DropAction()) + addAction("refused.doq.tests.powerdns.com.", RCodeAction(DNSRCode.REFUSED)) + addAction("spoof.doq.tests.powerdns.com.", SpoofAction("1.2.3.4")) + addAction("no-backend.doq.tests.powerdns.com.", PoolAction('this-pool-has-no-backend')) + + addDOH3Local("127.0.0.1:%d", "%s", "%s") + """ + _config_params = ['_testServerPort', '_doqServerPort','_serverCert', '_serverKey'] + _verboseMode = True + + def testDOH3Simple(self): + """ + DOH3: Simple query + """ + name = 'simple.doq.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=False) + query.id = 0 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 0 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + (receivedQuery, receivedResponse) = self.sendDOH3Query(self._doqServerPort, self._dohBaseURL, query, response=response, caFile=self._caCert, serverName=self._serverName) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery)