]>
Commit | Line | Data |
---|---|---|
a227f47d RG |
1 | |
2 | #pragma once | |
3 | #include <memory> | |
4 | ||
5 | #include "misc.hh" | |
6 | ||
d0ae6360 RG |
/* Result of a non-blocking IO attempt: Done when the operation completed,
   NeedRead / NeedWrite when it would block and has to be retried later. */
enum class IOState
{
  Done,
  NeedRead,
  NeedWrite
};
8 | ||
a227f47d RG |
9 | class TLSConnection |
10 | { | |
11 | public: | |
12 | virtual ~TLSConnection() { } | |
d0ae6360 RG |
13 | virtual void doHandshake() = 0; |
14 | virtual IOState tryHandshake() = 0; | |
a227f47d RG |
15 | virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0; |
16 | virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0; | |
d0ae6360 RG |
17 | virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0; |
18 | virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0; | |
046bac5c | 19 | virtual std::string getServerNameIndication() = 0; |
a227f47d RG |
20 | virtual void close() = 0; |
21 | ||
22 | protected: | |
23 | int d_socket{-1}; | |
24 | }; | |
25 | ||
26 | class TLSCtx | |
27 | { | |
28 | public: | |
507bb0ee RG |
29 | TLSCtx() |
30 | { | |
31 | d_rotatingTicketsKey.clear(); | |
32 | } | |
a227f47d RG |
33 | virtual ~TLSCtx() {} |
34 | virtual std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) = 0; | |
35 | virtual void rotateTicketsKey(time_t now) = 0; | |
36 | virtual void loadTicketsKeys(const std::string& file) | |
37 | { | |
38 | throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file"); | |
39 | } | |
40 | ||
41 | void handleTicketsKeyRotation(time_t now) | |
42 | { | |
43 | if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) { | |
44 | if (d_rotatingTicketsKey.test_and_set()) { | |
45 | /* someone is already rotating */ | |
46 | return; | |
47 | } | |
48 | try { | |
49 | rotateTicketsKey(now); | |
50 | d_rotatingTicketsKey.clear(); | |
51 | } | |
52 | catch(const std::runtime_error& e) { | |
53 | d_rotatingTicketsKey.clear(); | |
1f65c18c RG |
54 | throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what()); |
55 | } | |
56 | catch(...) { | |
57 | d_rotatingTicketsKey.clear(); | |
58 | throw; | |
a227f47d RG |
59 | } |
60 | } | |
61 | } | |
62 | ||
63 | time_t getNextTicketsKeyRotation() const | |
64 | { | |
65 | return d_ticketsKeyNextRotation; | |
66 | } | |
67 | ||
68 | virtual size_t getTicketsKeysCount() = 0; | |
69 | ||
70 | protected: | |
507bb0ee | 71 | std::atomic_flag d_rotatingTicketsKey; |
a227f47d RG |
72 | time_t d_ticketsKeyRotationDelay{0}; |
73 | time_t d_ticketsKeyNextRotation{0}; | |
74 | }; | |
75 | ||
76 | class TLSFrontend | |
77 | { | |
78 | public: | |
79 | bool setupTLS(); | |
80 | ||
81 | void rotateTicketsKey(time_t now) | |
82 | { | |
83 | if (d_ctx != nullptr) { | |
84 | d_ctx->rotateTicketsKey(now); | |
85 | } | |
86 | } | |
87 | ||
88 | void loadTicketsKeys(const std::string& file) | |
89 | { | |
90 | if (d_ctx != nullptr) { | |
91 | d_ctx->loadTicketsKeys(file); | |
92 | } | |
93 | } | |
94 | ||
95 | std::shared_ptr<TLSCtx> getContext() | |
96 | { | |
97 | return d_ctx; | |
98 | } | |
99 | ||
100 | void cleanup() | |
101 | { | |
102 | d_ctx.reset(); | |
103 | } | |
104 | ||
105 | size_t getTicketsKeysCount() | |
106 | { | |
107 | if (d_ctx != nullptr) { | |
108 | return d_ctx->getTicketsKeysCount(); | |
109 | } | |
110 | ||
111 | return 0; | |
112 | } | |
113 | ||
114 | static std::string timeToString(time_t rotationTime) | |
115 | { | |
116 | char buf[20]; | |
117 | struct tm date_tm; | |
118 | ||
119 | localtime_r(&rotationTime, &date_tm); | |
120 | strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm); | |
121 | ||
122 | return std::string(buf); | |
123 | } | |
124 | ||
125 | time_t getTicketsKeyRotationDelay() const | |
126 | { | |
127 | return d_ticketsKeyRotationDelay; | |
128 | } | |
129 | ||
130 | std::string getNextTicketsKeyRotation() const | |
131 | { | |
132 | std::string res; | |
133 | ||
134 | if (d_ctx != nullptr) { | |
135 | res = timeToString(d_ctx->getNextTicketsKeyRotation()); | |
136 | } | |
137 | ||
138 | return res; | |
139 | } | |
140 | ||
fa974ada | 141 | std::vector<std::pair<std::string, std::string>> d_certKeyPairs; |
a227f47d | 142 | ComboAddress d_addr; |
a227f47d | 143 | std::string d_ciphers; |
9e67ac67 | 144 | std::string d_ciphers13; |
a227f47d | 145 | std::string d_provider; |
a227f47d RG |
146 | std::string d_ticketKeyFile; |
147 | ||
d395c941 | 148 | size_t d_maxStoredSessions{20480}; |
a227f47d | 149 | time_t d_ticketsKeyRotationDelay{43200}; |
a227f47d | 150 | uint8_t d_numberOfTicketsKeys{5}; |
ba20dc97 | 151 | bool d_enableTickets{true}; |
a227f47d RG |
152 | |
153 | private: | |
154 | std::shared_ptr<TLSCtx> d_ctx{nullptr}; | |
155 | }; | |
156 | ||
157 | class TCPIOHandler | |
158 | { | |
159 | public: | |
d0ae6360 | 160 | |
a227f47d RG |
161 | TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket) |
162 | { | |
163 | if (ctx) { | |
164 | d_conn = ctx->getConnection(d_socket, timeout, now); | |
165 | } | |
166 | } | |
d0ae6360 | 167 | |
a227f47d RG |
168 | ~TCPIOHandler() |
169 | { | |
170 | if (d_conn) { | |
171 | d_conn->close(); | |
172 | } | |
173 | else if (d_socket != -1) { | |
174 | shutdown(d_socket, SHUT_RDWR); | |
175 | } | |
176 | } | |
d0ae6360 RG |
177 | |
178 | IOState tryHandshake() | |
179 | { | |
180 | if (d_conn) { | |
181 | return d_conn->tryHandshake(); | |
182 | } | |
183 | return IOState::Done; | |
184 | } | |
185 | ||
a227f47d RG |
186 | size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) |
187 | { | |
188 | if (d_conn) { | |
189 | return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout); | |
190 | } else { | |
191 | return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout); | |
192 | } | |
193 | } | |
d0ae6360 RG |
194 | |
195 | /* Tries to read exactly toRead bytes into the buffer, starting at position pos. | |
196 | Updates pos everytime a successful read occurs, | |
197 | throws an std::runtime_error in case of IO error, | |
198 | return Done when toRead bytes have been read, needRead or needWrite if the IO operation | |
199 | would block. | |
200 | */ | |
201 | IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) | |
202 | { | |
acadc544 RG |
203 | if (buffer.size() < (pos + toRead)) { |
204 | throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos)); | |
205 | } | |
206 | ||
d0ae6360 RG |
207 | if (d_conn) { |
208 | return d_conn->tryRead(buffer, pos, toRead); | |
209 | } | |
210 | ||
211 | size_t got = 0; | |
212 | do { | |
213 | ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got); | |
214 | if (res == 0) { | |
215 | throw runtime_error("EOF while reading message"); | |
216 | } | |
217 | if (res < 0) { | |
218 | if (errno == EAGAIN || errno == EWOULDBLOCK) { | |
219 | return IOState::NeedRead; | |
220 | } | |
221 | else { | |
222 | throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno)); | |
223 | } | |
224 | } | |
225 | ||
226 | pos += static_cast<size_t>(res); | |
227 | got += static_cast<size_t>(res); | |
228 | } | |
229 | while (got < toRead); | |
230 | ||
231 | return IOState::Done; | |
232 | } | |
233 | ||
234 | /* Tries to write exactly toWrite bytes from the buffer, starting at position pos. | |
235 | Updates pos everytime a successful write occurs, | |
236 | throws an std::runtime_error in case of IO error, | |
237 | return Done when toWrite bytes have been written, needRead or needWrite if the IO operation | |
238 | would block. | |
239 | */ | |
240 | IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) | |
241 | { | |
242 | if (d_conn) { | |
243 | return d_conn->tryWrite(buffer, pos, toWrite); | |
244 | } | |
245 | ||
246 | size_t sent = 0; | |
247 | do { | |
248 | ssize_t res = ::write(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toWrite - sent); | |
249 | if (res == 0) { | |
250 | throw runtime_error("EOF while sending message"); | |
251 | } | |
252 | if (res < 0) { | |
253 | if (errno == EAGAIN || errno == EWOULDBLOCK) { | |
254 | return IOState::NeedWrite; | |
255 | } | |
256 | else { | |
257 | throw std::runtime_error(std::string("Error while writing message: ") + strerror(errno)); | |
258 | } | |
259 | } | |
260 | ||
261 | pos += static_cast<size_t>(res); | |
262 | sent += static_cast<size_t>(res); | |
263 | } | |
264 | while (sent < toWrite); | |
265 | ||
266 | return IOState::Done; | |
267 | } | |
268 | ||
a227f47d RG |
269 | size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) |
270 | { | |
271 | if (d_conn) { | |
272 | return d_conn->write(buffer, bufferSize, writeTimeout); | |
273 | } | |
274 | else { | |
275 | return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout); | |
276 | } | |
277 | } | |
278 | ||
046bac5c RG |
279 | std::string getServerNameIndication() |
280 | { | |
281 | if (d_conn) { | |
282 | return d_conn->getServerNameIndication(); | |
283 | } | |
284 | return std::string(); | |
285 | } | |
286 | ||
a227f47d RG |
287 | private: |
288 | std::unique_ptr<TLSConnection> d_conn{nullptr}; | |
289 | int d_socket{-1}; | |
290 | }; |