git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/proxyprotocol.py
dnsdist: Make the Proxy Protocol tests compatible with Python 2
[thirdparty/pdns.git] / regression-tests.dnsdist / proxyprotocol.py
1 #!/usr/bin/env python
2
3 import socket
4 import struct
5
class ProxyProtocol(object):
    """Incremental parser for a PROXY protocol version 2 header.

    Call parseHeader(), then parseAddressesAndPorts(), then
    parseAdditionalValues() on the same buffer; each returns False when the
    buffer is malformed or too short.  consumed() reports the offset just
    past the fields parsed so far.

    Kept compatible with Python 2 and Python 3: ``data`` may be a Python 2
    ``str`` or a ``bytes``/``bytearray`` object.
    """

    # Fixed signature every v2 header starts with.
    MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
    # Header is magic + versioncommand (1) + family (1) + content length (2)
    HEADER_SIZE = len(MAGIC) + 1 + 1 + 2
    PORT_SIZE = 2
    # Each additional value (TLV) starts with a type (1) and a length (2).
    TLV_HEADER_SIZE = 1 + 2

    def __init__(self):
        # Initialise the cursor so consumed() returns 0 instead of raising
        # AttributeError when called before parseHeader() has set it.
        self.offset = 0
        self.local = False
        self.values = []

    def consumed(self):
        """Return how many bytes of the buffer have been parsed so far."""
        return self.offset

    def parseHeader(self, data):
        """Parse the fixed-size header at the start of ``data``.

        Sets version, command, local, offset and contentLen; for a PROXY
        command (0x1) it also sets family, addrSize, protocol and tcp.

        Returns True for a v2 header whose command is LOCAL (0x0) or PROXY
        (0x1), whose family is 0x1 (4-byte addresses) or 0x2 (16-byte
        addresses), whose protocol is 0x1 (tcp=True) or 0x2 (tcp=False),
        and whose content length can hold both addresses and ports.
        Returns False otherwise.
        """
        if len(data) < self.HEADER_SIZE:
            return False

        if data[:len(self.MAGIC)] != self.MAGIC:
            return False

        # High nibble: protocol version; low nibble: command.
        value = struct.unpack_from('!B', data, len(self.MAGIC))[0]
        self.version = value >> 4
        if self.version != 0x02:
            return False

        self.command = value & 0x0F
        self.local = False
        self.offset = self.HEADER_SIZE

        if self.command == 0x00:
            self.local = True
        elif self.command == 0x01:
            # High nibble: address family; low nibble: transport protocol.
            value = struct.unpack_from('!B', data, len(self.MAGIC) + 1)[0]
            self.family = value >> 4
            if self.family == 0x01:
                self.addrSize = 4
            elif self.family == 0x02:
                self.addrSize = 16
            else:
                return False

            self.protocol = value & 0x0F
            if self.protocol == 0x01:
                self.tcp = True
            elif self.protocol == 0x02:
                self.tcp = False
            else:
                return False
        else:
            return False

        self.contentLen = struct.unpack_from('!H', data, len(self.MAGIC) + 2)[0]

        # A PROXY header must at least carry both addresses and both ports.
        if not self.local:
            if self.contentLen < (self.addrSize * 2 + self.PORT_SIZE * 2):
                return False

        return True

    def getAddr(self, data):
        """Read one address at the current offset and advance past it.

        Returns the address in presentation (text) form, or False if
        ``data`` is too short.
        """
        if len(data) < (self.consumed() + self.addrSize):
            return False

        family = socket.AF_INET if self.family == 0x01 else socket.AF_INET6
        value = socket.inet_ntop(family, data[self.offset:self.offset + self.addrSize])
        self.offset += self.addrSize
        return value

    def getPort(self, data):
        """Read one big-endian 16-bit port and advance past it.

        Returns the port number, or False if ``data`` is too short.
        """
        if len(data) < (self.consumed() + self.PORT_SIZE):
            return False

        value = struct.unpack_from('!H', data, self.offset)[0]
        self.offset += self.PORT_SIZE
        return value

    def parseAddressesAndPorts(self, data):
        """Read source/destination addresses then source/destination ports.

        A LOCAL header carries none, so True is returned immediately.
        Returns False if ``data`` is too short to hold all four fields.
        """
        if self.local:
            return True

        if len(data) < (self.consumed() + self.addrSize * 2 + self.PORT_SIZE * 2):
            return False

        self.source = self.getAddr(data)
        self.destination = self.getAddr(data)
        self.sourcePort = self.getPort(data)
        self.destinationPort = self.getPort(data)
        return True

    def parseAdditionalValues(self, data):
        """Parse the TLVs that follow the addresses, up to contentLen.

        Fills ``self.values`` with ``[type, value]`` pairs, where ``value``
        is a slice of ``data`` (so it has the same type as ``data``, even
        when empty).  Returns False if ``data`` is shorter than the header
        announces or a TLV claims more bytes than remain.
        """
        self.values = []
        if self.local:
            return True

        end = self.HEADER_SIZE + self.contentLen
        if len(data) < end:
            return False

        remaining = end - self.consumed()
        # NOTE(review): 1-2 trailing bytes too short for a TLV header are
        # neither parsed nor skipped, so consumed() stops short of them.
        while remaining >= self.TLV_HEADER_SIZE:
            valueType, valueLen = struct.unpack_from('!BH', data, self.offset)
            self.offset += self.TLV_HEADER_SIZE
            remaining -= self.TLV_HEADER_SIZE

            if valueLen > remaining:
                return False

            # Slicing (rather than a literal) keeps empty values the same
            # type as non-empty ones: b'' under Python 3, '' under Python 2.
            self.values.append([valueType, data[self.offset:self.offset + valueLen]])
            self.offset += valueLen
            remaining -= valueLen

        return True