]> git.ipfire.org Git - thirdparty/pdns.git/blame - pdns/tcpiohandler.hh
rec: Only log qname parsing errors when 'log-common-errors' is set
[thirdparty/pdns.git] / pdns / tcpiohandler.hh
CommitLineData
a227f47d
RG
1
#pragma once
#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <sys/socket.h>
#include <unistd.h>

#include "misc.hh"
6
d0ae6360
RG
/* Result of a non-blocking I/O attempt: the operation completed (Done), or it
   would block until the socket becomes readable (NeedRead) or writable (NeedWrite). */
enum class IOState : int
{
  Done,
  NeedRead,
  NeedWrite
};
8
a227f47d
RG
9class TLSConnection
10{
11public:
12 virtual ~TLSConnection() { }
d0ae6360
RG
13 virtual void doHandshake() = 0;
14 virtual IOState tryHandshake() = 0;
a227f47d
RG
15 virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
16 virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
d0ae6360
RG
17 virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
18 virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
046bac5c 19 virtual std::string getServerNameIndication() = 0;
a227f47d
RG
20 virtual void close() = 0;
21
22protected:
23 int d_socket{-1};
24};
25
26class TLSCtx
27{
28public:
507bb0ee
RG
29 TLSCtx()
30 {
31 d_rotatingTicketsKey.clear();
32 }
a227f47d
RG
33 virtual ~TLSCtx() {}
34 virtual std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) = 0;
35 virtual void rotateTicketsKey(time_t now) = 0;
36 virtual void loadTicketsKeys(const std::string& file)
37 {
38 throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
39 }
40
41 void handleTicketsKeyRotation(time_t now)
42 {
43 if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
44 if (d_rotatingTicketsKey.test_and_set()) {
45 /* someone is already rotating */
46 return;
47 }
48 try {
49 rotateTicketsKey(now);
50 d_rotatingTicketsKey.clear();
51 }
52 catch(const std::runtime_error& e) {
53 d_rotatingTicketsKey.clear();
1f65c18c
RG
54 throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
55 }
56 catch(...) {
57 d_rotatingTicketsKey.clear();
58 throw;
a227f47d
RG
59 }
60 }
61 }
62
63 time_t getNextTicketsKeyRotation() const
64 {
65 return d_ticketsKeyNextRotation;
66 }
67
68 virtual size_t getTicketsKeysCount() = 0;
69
70protected:
507bb0ee 71 std::atomic_flag d_rotatingTicketsKey;
a227f47d
RG
72 time_t d_ticketsKeyRotationDelay{0};
73 time_t d_ticketsKeyNextRotation{0};
74};
75
76class TLSFrontend
77{
78public:
79 bool setupTLS();
80
81 void rotateTicketsKey(time_t now)
82 {
83 if (d_ctx != nullptr) {
84 d_ctx->rotateTicketsKey(now);
85 }
86 }
87
88 void loadTicketsKeys(const std::string& file)
89 {
90 if (d_ctx != nullptr) {
91 d_ctx->loadTicketsKeys(file);
92 }
93 }
94
95 std::shared_ptr<TLSCtx> getContext()
96 {
97 return d_ctx;
98 }
99
100 void cleanup()
101 {
102 d_ctx.reset();
103 }
104
105 size_t getTicketsKeysCount()
106 {
107 if (d_ctx != nullptr) {
108 return d_ctx->getTicketsKeysCount();
109 }
110
111 return 0;
112 }
113
114 static std::string timeToString(time_t rotationTime)
115 {
116 char buf[20];
117 struct tm date_tm;
118
119 localtime_r(&rotationTime, &date_tm);
120 strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
121
122 return std::string(buf);
123 }
124
125 time_t getTicketsKeyRotationDelay() const
126 {
127 return d_ticketsKeyRotationDelay;
128 }
129
130 std::string getNextTicketsKeyRotation() const
131 {
132 std::string res;
133
134 if (d_ctx != nullptr) {
135 res = timeToString(d_ctx->getNextTicketsKeyRotation());
136 }
137
138 return res;
139 }
140
fa974ada 141 std::vector<std::pair<std::string, std::string>> d_certKeyPairs;
a227f47d 142 ComboAddress d_addr;
a227f47d 143 std::string d_ciphers;
9e67ac67 144 std::string d_ciphers13;
a227f47d 145 std::string d_provider;
a227f47d
RG
146 std::string d_ticketKeyFile;
147
d395c941 148 size_t d_maxStoredSessions{20480};
a227f47d 149 time_t d_ticketsKeyRotationDelay{43200};
a227f47d 150 uint8_t d_numberOfTicketsKeys{5};
ba20dc97 151 bool d_enableTickets{true};
a227f47d
RG
152
153private:
154 std::shared_ptr<TLSCtx> d_ctx{nullptr};
155};
156
157class TCPIOHandler
158{
159public:
d0ae6360 160
a227f47d
RG
161 TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
162 {
163 if (ctx) {
164 d_conn = ctx->getConnection(d_socket, timeout, now);
165 }
166 }
d0ae6360 167
a227f47d
RG
168 ~TCPIOHandler()
169 {
170 if (d_conn) {
171 d_conn->close();
172 }
173 else if (d_socket != -1) {
174 shutdown(d_socket, SHUT_RDWR);
175 }
176 }
d0ae6360
RG
177
178 IOState tryHandshake()
179 {
180 if (d_conn) {
181 return d_conn->tryHandshake();
182 }
183 return IOState::Done;
184 }
185
a227f47d
RG
186 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
187 {
188 if (d_conn) {
189 return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout);
190 } else {
191 return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
192 }
193 }
d0ae6360
RG
194
195 /* Tries to read exactly toRead bytes into the buffer, starting at position pos.
196 Updates pos everytime a successful read occurs,
197 throws an std::runtime_error in case of IO error,
198 return Done when toRead bytes have been read, needRead or needWrite if the IO operation
199 would block.
200 */
201 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
202 {
acadc544
RG
203 if (buffer.size() < (pos + toRead)) {
204 throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
205 }
206
d0ae6360
RG
207 if (d_conn) {
208 return d_conn->tryRead(buffer, pos, toRead);
209 }
210
211 size_t got = 0;
212 do {
213 ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
214 if (res == 0) {
215 throw runtime_error("EOF while reading message");
216 }
217 if (res < 0) {
218 if (errno == EAGAIN || errno == EWOULDBLOCK) {
219 return IOState::NeedRead;
220 }
221 else {
222 throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
223 }
224 }
225
226 pos += static_cast<size_t>(res);
227 got += static_cast<size_t>(res);
228 }
229 while (got < toRead);
230
231 return IOState::Done;
232 }
233
234 /* Tries to write exactly toWrite bytes from the buffer, starting at position pos.
235 Updates pos everytime a successful write occurs,
236 throws an std::runtime_error in case of IO error,
237 return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
238 would block.
239 */
240 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite)
241 {
242 if (d_conn) {
243 return d_conn->tryWrite(buffer, pos, toWrite);
244 }
245
246 size_t sent = 0;
247 do {
248 ssize_t res = ::write(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toWrite - sent);
249 if (res == 0) {
250 throw runtime_error("EOF while sending message");
251 }
252 if (res < 0) {
253 if (errno == EAGAIN || errno == EWOULDBLOCK) {
254 return IOState::NeedWrite;
255 }
256 else {
257 throw std::runtime_error(std::string("Error while writing message: ") + strerror(errno));
258 }
259 }
260
261 pos += static_cast<size_t>(res);
262 sent += static_cast<size_t>(res);
263 }
264 while (sent < toWrite);
265
266 return IOState::Done;
267 }
268
a227f47d
RG
269 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
270 {
271 if (d_conn) {
272 return d_conn->write(buffer, bufferSize, writeTimeout);
273 }
274 else {
275 return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
276 }
277 }
278
046bac5c
RG
279 std::string getServerNameIndication()
280 {
281 if (d_conn) {
282 return d_conn->getServerNameIndication();
283 }
284 return std::string();
285 }
286
a227f47d
RG
287private:
288 std::unique_ptr<TLSConnection> d_conn{nullptr};
289 int d_socket{-1};
290};