git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.common/proxyprotocol.py
Merge pull request #8945 from rgacogne/ddist-x-forwarded-for
[thirdparty/pdns.git] / regression-tests.common / proxyprotocol.py
1 #!/usr/bin/env python
2
3 import copy
4 import socket
5 import struct
6
class ProxyProtocol(object):
    """Build and parse PROXY protocol version 2 headers.

    Parsing is incremental and stateful: call parseHeader(), then
    parseAddressesAndPorts(), then parseAdditionalValues() on the same
    buffer. Each step advances self.offset, and consumed() afterwards
    reports how many leading bytes of the buffer belonged to the PROXY
    header. Every parse method returns False when the data is too short
    or malformed, True otherwise.

    getPayload() builds a complete header from scratch.
    """

    # 12-byte v2 signature: "\r\n\r\n\0\r\nQUIT\n".
    MAGIC = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
    # Header is magic + versioncommand (1) + family (1) + content length (2)
    HEADER_SIZE = len(MAGIC) + 1 + 1 + 2
    PORT_SIZE = 2
    # Every additional value (TLV) starts with type (1) + length (2).
    TLV_HEADER_SIZE = 3

    def consumed(self):
        """Return the number of bytes of the buffer parsed so far.

        Only meaningful once parseHeader() has been called, since that is
        what initializes self.offset.
        """
        return self.offset

    def parseHeader(self, data):
        """Parse the fixed-size header at the start of data.

        Always sets version once the magic matches. On success it also
        sets:
        - command, local and offset;
        - contentLen;
        - for a PROXY command (0x01): family, addrSize, protocol and tcp.

        Returns True for a version 2 header whose command is LOCAL (0x00)
        or PROXY (0x01). A PROXY header must also use family 0x01 (IPv4)
        or 0x02 (IPv6) and protocol 0x01 (TCP) or 0x02 (UDP), and have a
        content length large enough for both addresses and both ports.
        Returns False otherwise.
        """
        if len(data) < self.HEADER_SIZE:
            return False

        if data[:len(self.MAGIC)] != self.MAGIC:
            return False

        # unpack_from reads directly from the buffer and behaves the same
        # for Python 2 str and Python 3 bytes/bytearray input.
        versionCommand, familyProtocol, contentLen = struct.unpack_from('!BBH', data, len(self.MAGIC))

        self.version = versionCommand >> 4
        if self.version != 0x02:
            return False

        # The version check above pins the high nibble, so the command is
        # simply the low nibble.
        self.command = versionCommand & 0x0F
        self.local = False
        self.offset = self.HEADER_SIZE

        if self.command == 0x00:
            self.local = True
        elif self.command == 0x01:
            self.family = familyProtocol >> 4
            if self.family == 0x01:
                self.addrSize = 4
            elif self.family == 0x02:
                self.addrSize = 16
            else:
                return False

            self.protocol = familyProtocol & 0x0F
            if self.protocol == 0x01:
                self.tcp = True
            elif self.protocol == 0x02:
                self.tcp = False
            else:
                return False
        else:
            return False

        self.contentLen = contentLen

        if not self.local:
            if self.contentLen < (self.addrSize + self.PORT_SIZE) * 2:
                return False

        return True

    def getAddr(self, data):
        """Read one address at the current offset and advance past it.

        Returns the textual form produced by socket.inet_ntop, or False if
        data is too short.
        """
        if len(data) < (self.consumed() + self.addrSize):
            return False

        af = socket.AF_INET if self.family == 0x01 else socket.AF_INET6
        value = socket.inet_ntop(af, data[self.offset:self.offset + self.addrSize])

        self.offset += self.addrSize
        return value

    def getPort(self, data):
        """Read one big-endian 16-bit port at the current offset and advance.

        Returns the port as an int, or False if data is too short.
        """
        if len(data) < (self.consumed() + self.PORT_SIZE):
            return False

        value = struct.unpack_from('!H', data, self.offset)[0]
        self.offset += self.PORT_SIZE
        return value

    def parseAddressesAndPorts(self, data):
        """Parse the address block that follows the fixed header.

        Sets source, destination, sourcePort and destinationPort. A LOCAL
        header has no address block, so this returns True without doing
        anything. Returns False if data is too short.
        """
        if self.local:
            return True

        if len(data) < (self.consumed() + (self.addrSize + self.PORT_SIZE) * 2):
            return False

        # Wire order: source address, destination address, source port,
        # destination port.
        self.source = self.getAddr(data)
        self.destination = self.getAddr(data)
        self.sourcePort = self.getPort(data)
        self.destinationPort = self.getPort(data)
        return True

    def parseAdditionalValues(self, data):
        """Parse the TLVs between the address block and the header's end.

        Fills self.values with [type, value] pairs. The values are slices
        of data, so they have the same type as data, including for
        zero-length values.

        For a LOCAL header the content block is skipped without being
        parsed: values stays empty, but offset still moves to the end of
        the header so that consumed() covers the whole PROXY header.

        Returns False if data is shorter than HEADER_SIZE + contentLen, if
        a TLV claims more bytes than remain, or if 1-2 bytes are left over
        that cannot form a TLV.
        """
        self.values = []

        end = self.HEADER_SIZE + self.contentLen
        if len(data) < end:
            return False

        if self.local:
            self.offset = end
            return True

        remaining = end - self.consumed()
        while remaining >= self.TLV_HEADER_SIZE:
            valueType, valueLen = struct.unpack_from('!BH', data, self.offset)
            self.offset += self.TLV_HEADER_SIZE
            remaining -= self.TLV_HEADER_SIZE

            if valueLen > remaining:
                return False

            # Slicing (even with valueLen == 0) keeps the value type
            # consistent with data, instead of mixing str and bytes.
            self.values.append([valueType, data[self.offset:self.offset + valueLen]])
            self.offset += valueLen
            remaining -= valueLen

        # Leftover bytes too short for a TLV header mean a malformed block.
        return remaining == 0

    @classmethod
    def getPayload(cls, local, tcp, v6, source, destination, sourcePort, destinationPort, values):
        """Build a version 2 PROXY header and return it as bytes.

        local -- if true, build a LOCAL header (command 0x00, family and
                 protocol 0x00, no address block); the address and port
                 arguments are then ignored.
        tcp   -- protocol 0x01 (TCP) if true, 0x02 (UDP) otherwise.
        v6    -- family 0x02 (IPv6) if true, 0x01 (IPv4) otherwise.
        source, destination -- textual addresses for socket.inet_pton.
        sourcePort, destinationPort -- ports, packed as unsigned 16-bit.
        values -- sequence of [type, bytes] pairs emitted as TLVs.
        """
        version = 0x02
        command = 0x00 if local else 0x01

        addrSize = 0
        family = 0x00
        protocol = 0x00
        if not local:
            protocol = 0x01 if tcp else 0x02
            # sorry but compatibility with python 2 is awful for this,
            # not going to waste time on it
            if v6:
                family = 0x02
                addrSize = 16
            else:
                family = 0x01
                addrSize = 4

        contentSize = 0
        if not local:
            contentSize = (addrSize + cls.PORT_SIZE) * 2
        contentSize += sum(cls.TLV_HEADER_SIZE + len(value[1]) for value in values)

        parts = [
            cls.MAGIC,
            struct.pack('!BBH', (version << 4) | command, (family << 4) | protocol, contentSize),
        ]

        if not local:
            af = socket.AF_INET if family == 0x01 else socket.AF_INET6
            parts.append(socket.inet_pton(af, source))
            parts.append(socket.inet_pton(af, destination))
            parts.append(struct.pack('!HH', sourcePort, destinationPort))

        for value in values:
            parts.append(struct.pack('!BH', value[0], len(value[1])))
            parts.append(value[1])

        return b''.join(parts)