]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/tcpiohandler.hh
Avoid throwing an exception in Logger::log().
[thirdparty/pdns.git] / pdns / tcpiohandler.hh
1
#pragma once
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>
#include <cerrno>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "libssl.hh"
#include "misc.hh"
7
/* Outcome of a (possibly non-blocking) I/O attempt. */
enum class IOState
{
  Done,      /* the requested operation completed */
  NeedRead,  /* the operation would block waiting for the socket to become readable */
  NeedWrite  /* the operation would block waiting for the socket to become writable */
};
9
10 class TLSConnection
11 {
12 public:
13 virtual ~TLSConnection() { }
14 virtual void doHandshake() = 0;
15 virtual IOState tryHandshake() = 0;
16 virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
17 virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
18 virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
19 virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
20 virtual std::string getServerNameIndication() const = 0;
21 virtual LibsslTLSVersion getTLSVersion() const = 0;
22 virtual bool hasSessionBeenResumed() const = 0;
23 virtual void close() = 0;
24
25 void setUnknownTicketKey()
26 {
27 d_unknownTicketKey = true;
28 }
29
30 bool getUnknownTicketKey() const
31 {
32 return d_unknownTicketKey;
33 }
34
35 void setResumedFromInactiveTicketKey()
36 {
37 d_resumedFromInactiveTicketKey = true;
38 }
39
40 bool getResumedFromInactiveTicketKey() const
41 {
42 return d_resumedFromInactiveTicketKey;
43 }
44
45 protected:
46 int d_socket{-1};
47 bool d_unknownTicketKey{false};
48 bool d_resumedFromInactiveTicketKey{false};
49 };
50
51 class TLSCtx
52 {
53 public:
54 TLSCtx()
55 {
56 d_rotatingTicketsKey.clear();
57 }
58 virtual ~TLSCtx() {}
59 virtual std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) = 0;
60 virtual void rotateTicketsKey(time_t now) = 0;
61 virtual void loadTicketsKeys(const std::string& file)
62 {
63 throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
64 }
65
66 void handleTicketsKeyRotation(time_t now)
67 {
68 if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
69 if (d_rotatingTicketsKey.test_and_set()) {
70 /* someone is already rotating */
71 return;
72 }
73 try {
74 rotateTicketsKey(now);
75 d_rotatingTicketsKey.clear();
76 }
77 catch(const std::runtime_error& e) {
78 d_rotatingTicketsKey.clear();
79 throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
80 }
81 catch(...) {
82 d_rotatingTicketsKey.clear();
83 throw;
84 }
85 }
86 }
87
88 time_t getNextTicketsKeyRotation() const
89 {
90 return d_ticketsKeyNextRotation;
91 }
92
93 virtual size_t getTicketsKeysCount() = 0;
94
95 protected:
96 std::atomic_flag d_rotatingTicketsKey;
97 time_t d_ticketsKeyRotationDelay{0};
98 time_t d_ticketsKeyNextRotation{0};
99 };
100
101 class TLSFrontend
102 {
103 public:
104 bool setupTLS();
105
106 void rotateTicketsKey(time_t now)
107 {
108 if (d_ctx != nullptr) {
109 d_ctx->rotateTicketsKey(now);
110 }
111 }
112
113 void loadTicketsKeys(const std::string& file)
114 {
115 if (d_ctx != nullptr) {
116 d_ctx->loadTicketsKeys(file);
117 }
118 }
119
120 std::shared_ptr<TLSCtx> getContext()
121 {
122 return d_ctx;
123 }
124
125 void cleanup()
126 {
127 d_ctx.reset();
128 }
129
130 size_t getTicketsKeysCount()
131 {
132 if (d_ctx != nullptr) {
133 return d_ctx->getTicketsKeysCount();
134 }
135
136 return 0;
137 }
138
139 static std::string timeToString(time_t rotationTime)
140 {
141 char buf[20];
142 struct tm date_tm;
143
144 localtime_r(&rotationTime, &date_tm);
145 strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
146
147 return std::string(buf);
148 }
149
150 time_t getTicketsKeyRotationDelay() const
151 {
152 return d_tlsConfig.d_ticketsKeyRotationDelay;
153 }
154
155 std::string getNextTicketsKeyRotation() const
156 {
157 std::string res;
158
159 if (d_ctx != nullptr) {
160 res = timeToString(d_ctx->getNextTicketsKeyRotation());
161 }
162
163 return res;
164 }
165
166 TLSConfig d_tlsConfig;
167 TLSErrorCounters d_tlsCounters;
168 ComboAddress d_addr;
169 std::string d_provider;
170
171 private:
172 std::shared_ptr<TLSCtx> d_ctx{nullptr};
173 };
174
175 class TCPIOHandler
176 {
177 public:
178
179 TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
180 {
181 if (ctx) {
182 d_conn = ctx->getConnection(d_socket, timeout, now);
183 }
184 }
185
186 ~TCPIOHandler()
187 {
188 if (d_conn) {
189 d_conn->close();
190 }
191 else if (d_socket != -1) {
192 shutdown(d_socket, SHUT_RDWR);
193 }
194 }
195
196 IOState tryHandshake()
197 {
198 if (d_conn) {
199 return d_conn->tryHandshake();
200 }
201 return IOState::Done;
202 }
203
204 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
205 {
206 if (d_conn) {
207 return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout);
208 } else {
209 return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
210 }
211 }
212
213 /* Tries to read exactly toRead - pos bytes into the buffer, starting at position pos.
214 Updates pos everytime a successful read occurs,
215 throws an std::runtime_error in case of IO error,
216 return Done when toRead bytes have been read, needRead or needWrite if the IO operation
217 would block.
218 */
219 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
220 {
221 if (buffer.size() < toRead || pos >= toRead) {
222 throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
223 }
224
225 if (d_conn) {
226 return d_conn->tryRead(buffer, pos, toRead);
227 }
228
229 do {
230 ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos);
231 if (res == 0) {
232 throw runtime_error("EOF while reading message");
233 }
234 if (res < 0) {
235 if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
236 return IOState::NeedRead;
237 }
238 else {
239 throw std::runtime_error("Error while reading message: " + stringerror());
240 }
241 }
242
243 pos += static_cast<size_t>(res);
244 }
245 while (pos < toRead);
246
247 return IOState::Done;
248 }
249
250 /* Tries to write exactly toWrite - pos bytes from the buffer, starting at position pos.
251 Updates pos everytime a successful write occurs,
252 throws an std::runtime_error in case of IO error,
253 return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
254 would block.
255 */
256 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite)
257 {
258 if (buffer.size() < toWrite || pos >= toWrite) {
259 throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
260 }
261 if (d_conn) {
262 return d_conn->tryWrite(buffer, pos, toWrite);
263 }
264
265 do {
266 ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
267 if (res == 0) {
268 throw runtime_error("EOF while sending message");
269 }
270 if (res < 0) {
271 if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
272 return IOState::NeedWrite;
273 }
274 else {
275 throw std::runtime_error("Error while writing message: " + stringerror());
276 }
277 }
278
279 pos += static_cast<size_t>(res);
280 }
281 while (pos < toWrite);
282
283 return IOState::Done;
284 }
285
286 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
287 {
288 if (d_conn) {
289 return d_conn->write(buffer, bufferSize, writeTimeout);
290 }
291 else {
292 return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
293 }
294 }
295
296 std::string getServerNameIndication() const
297 {
298 if (d_conn) {
299 return d_conn->getServerNameIndication();
300 }
301 return std::string();
302 }
303
304 LibsslTLSVersion getTLSVersion() const
305 {
306 if (d_conn) {
307 return d_conn->getTLSVersion();
308 }
309 return LibsslTLSVersion::Unknown;
310 }
311
312 bool isTLS() const
313 {
314 return d_conn != nullptr;
315 }
316
317 bool hasTLSSessionBeenResumed() const
318 {
319 return d_conn && d_conn->hasSessionBeenResumed();
320 }
321
322 bool getResumedFromInactiveTicketKey() const
323 {
324 return d_conn && d_conn->getResumedFromInactiveTicketKey();
325 }
326
327 bool getUnknownTicketKey() const
328 {
329 return d_conn && d_conn->getUnknownTicketKey();
330 }
331
332 private:
333 std::unique_ptr<TLSConnection> d_conn{nullptr};
334 int d_socket{-1};
335 };