]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add socket_factory to allow socket creation to be overridden when needed.
authorBob Halley <halley@dnspython.org>
Sat, 28 May 2016 19:45:46 +0000 (12:45 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 28 May 2016 19:45:52 +0000 (12:45 -0700)
dns/query.py

index 356709837772db9466c1ef9adf5036cf78343a04..6b76b42e74f54d1999683e71f8f3710991c0e856 100644 (file)
@@ -37,6 +37,9 @@ if sys.version_info > (3,):
 else:
     select_error = select.error
 
+# Function used to create a socket.  Can be overridden if needed in special
+# situations.
+socket_factory = socket.socket
 
 class UnexpectedSource(dns.exception.DNSException):
 
@@ -223,7 +226,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    s = socket.socket(af, socket.SOCK_DGRAM, 0)
+    s = socket_factory(af, socket.SOCK_DGRAM, 0)
     begin_time = None
     try:
         expiration = _compute_expiration(timeout)
@@ -331,7 +334,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    s = socket.socket(af, socket.SOCK_STREAM, 0)
+    s = socket_factory(af, socket.SOCK_STREAM, 0)
     begin_time = None
     try:
         expiration = _compute_expiration(timeout)
@@ -435,9 +438,9 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
     if use_udp:
         if rdtype != dns.rdatatype.IXFR:
             raise ValueError('cannot do a UDP AXFR')
-        s = socket.socket(af, socket.SOCK_DGRAM, 0)
+        s = socket_factory(af, socket.SOCK_DGRAM, 0)
     else:
-        s = socket.socket(af, socket.SOCK_STREAM, 0)
+        s = socket_factory(af, socket.SOCK_STREAM, 0)
     s.setblocking(0)
     if source is not None:
         s.bind(source)