+++ /dev/null
-#!/bin/sh
-
-# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
-#
-# SPDX-License-Identifier: MPL-2.0
-#
-# This Source Code Form is subject to the terms of the Mozilla Public
-# License, v. 2.0. If a copy of the MPL was not distributed with this
-# file, you can obtain one at https://mozilla.org/MPL/2.0/.
-#
-# See the COPYRIGHT file distributed with this work for additional
-# information regarding copyright ownership.
-
-set -e
-
-# shellcheck source=../conf.sh
-. ../conf.sh
-
-dig_with_opts() {
- "${DIG}" -p "${PORT}" "$@"
-}
-
-rndccmd() {
- "${RNDC}" -p "${CONTROLPORT}" -c ../_common/rndc.conf -s "$@"
-}
-
-status=0
-n=0
-
-####################################################
-# NOTE: The next test resets the debug level to 1. #
-####################################################
-
-n=$((n + 1))
-echo_i "checking that BIND 9 doesn't crash on long TCP messages ($n)"
-ret=0
-# Avoid logging useless information.
-rndccmd 10.53.0.1 trace 1 || ret=1
-{ $PERL ../packet.pl -a "10.53.0.1" -p "${PORT}" -t tcp -r 300000 1996-alloc_dnsbuf-crash-test.pkt || ret=1; } | cat_i
-dig_with_opts +tcp @10.53.0.1 txt.example >dig.out.test$n || ret=1
-if [ $ret != 0 ]; then echo_i "failed"; fi
-status=$((status + ret))
-
-echo_i "exit status: $status"
-[ $status -eq 0 ] || exit 1
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
from types import TracebackType
from typing import NamedTuple
import asyncio
+import contextlib
import socket
import struct
import time
sock.close()
+async def send_long_tcp_stream(
+ host: str, port: int, message: dns.message.Message, min_bytes: int
+) -> None:
+ frame = message.to_wire(prepend_length=True)
+ chunk_frames = max(1, 65536 // len(frame))
+ frames_remaining = (min_bytes + len(frame) - 1) // len(frame)
+
+ async def discard_stream(reader: asyncio.StreamReader) -> None:
+ with contextlib.suppress(OSError):
+ while await reader.read(65535):
+ pass
+
+ async def run() -> None:
+ reader, writer = await asyncio.open_connection(host, port)
+ discard_task = asyncio.create_task(discard_stream(reader))
+ try:
+ remaining = frames_remaining
+ while remaining > 0:
+ frames = min(chunk_frames, remaining)
+ writer.write(frame * frames)
+ await writer.drain()
+ remaining -= frames
+
+ writer.write_eof()
+ await writer.drain()
+
+ writer.close()
+ with contextlib.suppress(ConnectionError, OSError):
+ await writer.wait_closed()
+ await discard_task
+ finally:
+ writer.close()
+ if not discard_task.done():
+ discard_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError):
+ await discard_task
+
+ await asyncio.wait_for(run(), timeout=10 * TIMEOUT)
+
+
def test_tcp_garbage(ns7: NamedInstance, named_port: int) -> None:
with create_socket(ns7.ip, named_port) as sock:
msg = isctest.query.create(
check_tcp_response(ns5.ip)
asyncio.run(run())
+
+
+def debug_level(ns: NamedInstance) -> int:
+ status = ns.rndc("status").out
+ matches = status.grep("debug level:")
+ assert matches, f"'debug level' not found in rndc status:\n{status}"
+ return int(matches[0].string.partition(":")[2])
+
+
+@contextlib.contextmanager
+def temporary_trace_level(ns: NamedInstance, level: int) -> Iterator[None]:
+ """Lower the debug level for a noisy section, then restore the default."""
+ prev_level = debug_level(ns)
+ ns.rndc(f"trace {level}")
+ try:
+ yield
+ finally:
+ # Don't mask an in-flight test failure if named has died.
+ ns.rndc(f"trace {prev_level}", raise_on_exception=False)
+
+
+def test_long_tcp_messages(ns1: NamedInstance, named_port: int) -> None:
+ isctest.log.info("checking that BIND 9 doesn't crash on long TCP messages")
+ stream_bytes = 6 * 1024 * 1024
+ msg = isctest.query.create(
+ "isc.org.",
+ "AXFR",
+ dnssec=False,
+ use_edns=False,
+ rd=False,
+ ad=False,
+ message_id=1,
+ )
+
+ # Avoid logging the huge query stream at the default debug level.
+ with temporary_trace_level(ns1, 1):
+ asyncio.run(send_long_tcp_stream(ns1.ip, named_port, msg, stream_bytes))
+
+ msg = isctest.query.create("txt.example.", "A")
+ isctest.query.tcp(msg, ns1.ip)