// Result of a non-blocking I/O step (tryHandshake / tryRead / tryWrite).
// Done: the requested operation has completed. NeedRead / NeedWrite: it could not
// make progress yet; presumably the caller retries once the socket becomes
// readable / writable (e.g. TCPIOHandler::tryRead returns NeedRead on EAGAIN).
8 enum class IOState { Done, NeedRead, NeedWrite };
// Abstract per-connection interface implemented by each TLS backend (the class
// header is outside this view; the destructor below names it TLSConnection).
// Instances are created by TLSCtx::getConnection() and owned by TCPIOHandler
// through a std::unique_ptr<TLSConnection>.
// Virtual so that a concrete backend is destroyed correctly through the base pointer.
13 virtual ~TLSConnection() { }
// Performs the TLS handshake. NOTE(review): presumably the blocking counterpart
// of tryHandshake() — confirm against the backends.
14 virtual void doHandshake() = 0;
// Non-blocking handshake step returning an IOState (presumably Done once the
// handshake has completed, NeedRead / NeedWrite when it must be resumed later).
15 virtual IOState tryHandshake() = 0;
// Reads into `buffer` (capacity `bufferSize`) and returns a byte count.
// TCPIOHandler::read() calls this instead of readn2WithTimeout() for TLS
// connections, so the timeout semantics presumably match that helper — confirm.
16 virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
// Write counterpart of read(); TCPIOHandler::write() calls this instead of
// writen2WithTimeout() for TLS connections.
17 virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
// Non-blocking variants, called by TCPIOHandler::tryWrite() / tryRead() with the
// same (buffer, pos, toWrite/toRead) arguments after TCPIOHandler has validated
// them. They advance `pos` and return an IOState.
18 virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
19 virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
// Server Name Indication as reported by the backend. TCPIOHandler returns an
// empty string when there is no TLS connection.
20 virtual std::string getServerNameIndication() const = 0;
// Protocol version in use. TCPIOHandler reports LibsslTLSVersion::Unknown when
// there is no TLS connection.
21 virtual LibsslTLSVersion getTLSVersion() const = 0;
// Whether the TLS session was resumed (per the name).
22 virtual bool hasSessionBeenResumed() const = 0;
// Closes the TLS connection. NOTE(review): whether this also closes the
// underlying socket is backend-specific — confirm.
23 virtual void close() = 0;
// Sticky per-connection ticket-key flags. The setters only ever set a flag to
// true, and no reset is visible here. NOTE(review): the callers are outside this
// view (presumably the backend while handling a session ticket) — confirm.
// Marks that an unknown ticket key was encountered (per the name).
25 void setUnknownTicketKey()
27 d_unknownTicketKey = true;
// True once setUnknownTicketKey() has been called.
30 bool getUnknownTicketKey() const
32 return d_unknownTicketKey;
// Marks that the session was resumed using a ticket key that is not the active
// one (per the name).
35 void setResumedFromInactiveTicketKey()
37 d_resumedFromInactiveTicketKey = true;
// True once setResumedFromInactiveTicketKey() has been called.
40 bool getResumedFromInactiveTicketKey() const
42 return d_resumedFromInactiveTicketKey;
// Both flags start out false. TCPIOHandler exposes them through its own getters,
// which return false when there is no TLS connection.
47 bool d_unknownTicketKey{false};
48 bool d_resumedFromInactiveTicketKey{false};
// d_rotatingTicketsKey is a std::atomic_flag declared without an initializer, so
// it is explicitly cleared here to give it a known "not rotating" state.
// NOTE(review): the enclosing function is outside this view (presumably the
// TLSCtx constructor) — confirm.
56 d_rotatingTicketsKey.clear();
// Creates the per-connection TLS object for `socket`. TCPIOHandler's constructor
// calls this with its socket, its timeout and the current time.
59 virtual std::unique_ptr<TLSConnection> getConnection(int socket, unsigned int timeout, time_t now) = 0;
// Rotates the tickets key. Called by handleTicketsKeyRotation() while it holds
// the d_rotatingTicketsKey flag. NOTE(review): presumably also responsible for
// advancing d_ticketsKeyNextRotation — confirm in the backends.
60 virtual void rotateTicketsKey(time_t now) = 0;
// Loads tickets keys from `file`. This default implementation always throws
// std::runtime_error, so only backends that override it support file-based keys.
61 virtual void loadTicketsKeys(const std::string& file)
63 throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
// Rotates the tickets key when automatic rotation is enabled
// (d_ticketsKeyRotationDelay != 0) and `now` is strictly past
// d_ticketsKeyNextRotation.
// d_rotatingTicketsKey serves as a try-lock, so only one caller rotates at a time.
// test_and_set() returning true means another caller already holds it.
// Throws std::runtime_error (with a context prefix) when rotateTicketsKey() fails
// with a std::runtime_error; the flag is released on that path too.
66 void handleTicketsKeyRotation(time_t now)
68 if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
69 if (d_rotatingTicketsKey.test_and_set()) {
// NOTE(review): the body of this branch is outside this view apart from the
// comment below; presumably it returns without rotating.
70 /* someone is already rotating */
// Flag acquired: rotate, then release the flag.
74 rotateTicketsKey(now);
75 d_rotatingTicketsKey.clear();
// The flag must be released before propagating. Otherwise test_and_set() would
// keep returning true and no later rotation could ever run.
77 catch(const std::runtime_error& e) {
78 d_rotatingTicketsKey.clear();
// NOTE(review): the prefix has no space after the colon, so the message reads
// "...TLS context:<reason>".
79 throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
// NOTE(review): the surrounding lines are outside this view; presumably this is
// a catch-all handler that releases the flag before rethrowing other exception
// types — confirm.
82 d_rotatingTicketsKey.clear();
// Returns the time after which handleTicketsKeyRotation() will next rotate the
// key (it rotates only when `now > d_ticketsKeyNextRotation`).
88 time_t getNextTicketsKeyRotation() const
90 return d_ticketsKeyNextRotation;
// Number of tickets keys held by the backend.
// NOTE(review): unlike the other getters this is not const — consider making it
// const if no backend needs to mutate state here.
93 virtual size_t getTicketsKeysCount() = 0;
// Set while a rotation is in progress; see handleTicketsKeyRotation().
96 std::atomic_flag d_rotatingTicketsKey;
// Rotation delay (unit presumably seconds); 0 disables automatic rotation in
// handleTicketsKeyRotation().
97 time_t d_ticketsKeyRotationDelay{0};
// Automatic rotation happens once `now` is strictly greater than this; starts at 0.
98 time_t d_ticketsKeyNextRotation{0};
// Forwards to the TLS context when one is set. No else branch is visible here,
// so this is presumably a no-op without a context.
106 void rotateTicketsKey(time_t now)
108 if (d_ctx != nullptr) {
109 d_ctx->rotateTicketsKey(now);
// Forwards to the TLS context when one is set. Note that TLSCtx's default
// loadTicketsKeys() throws std::runtime_error for backends that cannot load keys
// from a file.
113 void loadTicketsKeys(const std::string& file)
115 if (d_ctx != nullptr) {
116 d_ctx->loadTicketsKeys(file);
// Returns the TLS context. NOTE(review): the body is outside this view (it may
// do more than return d_ctx) — confirm.
120 std::shared_ptr<TLSCtx> getContext()
// Number of tickets keys held by the TLS context when one is set.
// NOTE(review): the fallback return value for a null d_ctx is outside this view
// (presumably 0).
130 size_t getTicketsKeysCount()
132 if (d_ctx != nullptr) {
133 return d_ctx->getTicketsKeysCount();
// Formats `rotationTime` as local time using the pattern "%Y-%m-%d %H:%M:%S".
// NOTE(review): `buf` and `date_tm` are declared on lines outside this view, and
// no check of the localtime_r() / strftime() results is visible here.
// strftime() returns 0 when the output does not fit, and the buffer contents are
// then indeterminate, so confirm `buf` is large enough and initialized.
139 static std::string timeToString(time_t rotationTime)
144 localtime_r(&rotationTime, &date_tm);
145 strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
147 return std::string(buf);
// Configured tickets key rotation delay, read from d_tlsConfig.
150 time_t getTicketsKeyRotationDelay() const
152 return d_tlsConfig.d_ticketsKeyRotationDelay;
// Next scheduled tickets key rotation of the TLS context, formatted by
// timeToString(). NOTE(review): `res` is declared outside this view; presumably
// it is returned empty when no context is set.
155 std::string getNextTicketsKeyRotation() const
159 if (d_ctx != nullptr) {
160 res = timeToString(d_ctx->getNextTicketsKeyRotation());
// TLS configuration; getTicketsKeyRotationDelay() reads d_ticketsKeyRotationDelay from it.
166 TLSConfig d_tlsConfig;
// TLS error counters. NOTE(review): they are not updated in the visible lines.
167 TLSErrorCounters d_tlsCounters;
// Presumably the name of the TLS backend/provider — confirm.
169 std::string d_provider;
// Shared TLS context, defaulting to nullptr. The visible methods that use it
// check for null first.
172 std::shared_ptr<TLSCtx> d_ctx{nullptr};
// Wraps `socket` and, through ctx->getConnection(), creates the TLS connection
// object for it.
// NOTE(review): `ctx` is dereferenced on the visible line; the null check that
// presumably guards it is outside this view — confirm.
179 TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
182 d_conn = ctx->getConnection(d_socket, timeout, now);
// NOTE(review): presumably part of the destructor. When there is no TLS
// connection (the preceding branch is outside this view), a valid plain socket
// is shut down in both directions. No close() of the descriptor is visible here.
191 else if (d_socket != -1) {
192 shutdown(d_socket, SHUT_RDWR);
// Advances the TLS handshake when a TLS connection is present; the d_conn check
// is outside this view. Plain TCP has no handshake, so IOState::Done is returned.
196 IOState tryHandshake()
199 return d_conn->tryHandshake();
201 return IOState::Done;
// Reads through the TLS connection when present; otherwise reads from the
// socket with readn2WithTimeout(). All arguments are passed through unchanged.
204 size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
207 return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout);
209 return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
213 /* Tries to read exactly toRead - pos bytes into the buffer, starting at position pos.
214 Updates pos every time a successful read occurs,
215 throws a std::runtime_error in case of IO error,
216 returns Done when toRead bytes have been read, NeedRead or NeedWrite if the IO operation
219 IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
// Argument validation: throws std::out_of_range (not std::runtime_error).
// Passing this check guarantees pos < toRead <= buffer.size(), so the
// buffer.at(pos) below cannot throw.
221 if (buffer.size() < toRead || pos >= toRead) {
// NOTE(review): when pos > toRead, `toRead - pos` wraps around (size_t), so the
// message reports a huge bogus byte count. The message also blames a "too small
// buffer" even when the real problem is pos >= toRead. Only the text is affected.
222 throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
// TLS path: the backend performs the read and advances `pos` itself.
226 return d_conn->tryRead(buffer, pos, toRead);
// Plain TCP path: keep reading until toRead bytes are in or the socket would block.
230 ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos);
// Presumably the res == 0 branch (its condition is outside this view): the peer
// closed the connection before toRead bytes arrived.
232 throw runtime_error("EOF while reading message");
// Would block: report NeedRead so the caller can retry later, with `pos` kept.
// NOTE(review): ENOTCONN is treated the same way, presumably because the socket
// may not be connected yet (e.g. TCP Fast Open) — confirm.
235 if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
236 return IOState::NeedRead;
239 throw std::runtime_error("Error while reading message: " + stringerror());
243 pos += static_cast<size_t>(res);
245 while (pos < toRead);
247 return IOState::Done;
250 /* Tries to write exactly toWrite - pos bytes from the buffer, starting at position pos.
251 Updates pos every time a successful write occurs,
252 throws a std::runtime_error in case of IO error,
253 returns Done when toWrite bytes have been written, NeedRead or NeedWrite if the IO operation
256 IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite)
// Argument validation: throws std::out_of_range (not std::runtime_error).
// Passing this check guarantees pos < toWrite <= buffer.size(), so the
// buffer.at(pos) below cannot throw.
258 if (buffer.size() < toWrite || pos >= toWrite) {
// NOTE(review): when pos > toWrite, `toWrite - pos` wraps around (size_t), so the
// message reports a huge bogus byte count. The message also blames a "too small
// buffer" even when the real problem is pos >= toWrite. Only the text is affected.
259 throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
// TLS path: the backend performs the write and advances `pos` itself.
262 return d_conn->tryWrite(buffer, pos, toWrite);
// Plain TCP path: keep writing until toWrite bytes are out or the socket would block.
// NOTE(review): ::write() on a socket whose peer has closed raises SIGPIPE unless
// the signal is ignored process-wide (not visible here) — confirm.
266 ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
// Presumably the res == 0 branch (its condition is outside this view).
268 throw runtime_error("EOF while sending message");
// Would block: report NeedWrite so the caller can retry later, with `pos` kept.
// NOTE(review): ENOTCONN is treated the same way, presumably because the socket
// may not be connected yet (e.g. TCP Fast Open) — confirm.
271 if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
272 return IOState::NeedWrite;
275 throw std::runtime_error("Error while writing message: " + stringerror());
279 pos += static_cast<size_t>(res);
281 while (pos < toWrite);
283 return IOState::Done;
// Writes through the TLS connection when present (the d_conn check is outside
// this view); otherwise writes to the socket with writen2WithTimeout(). All
// arguments are passed through unchanged.
286 size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
289 return d_conn->write(buffer, bufferSize, writeTimeout);
292 return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
// SNI reported by the TLS connection, or an empty string for plain TCP.
296 std::string getServerNameIndication() const
299 return d_conn->getServerNameIndication();
301 return std::string();
// TLS version reported by the TLS connection, or LibsslTLSVersion::Unknown for
// plain TCP.
304 LibsslTLSVersion getTLSVersion() const
307 return d_conn->getTLSVersion();
309 return LibsslTLSVersion::Unknown;
// NOTE(review): the signature of this accessor is outside this view. It reports
// whether this handler has a TLS connection at all.
314 return d_conn != nullptr;
// Each of the following is false for plain TCP (no TLS connection); otherwise
// it forwards to the TLS connection's corresponding flag.
317 bool hasTLSSessionBeenResumed() const
319 return d_conn && d_conn->hasSessionBeenResumed();
322 bool getResumedFromInactiveTicketKey() const
324 return d_conn && d_conn->getResumedFromInactiveTicketKey();
327 bool getUnknownTicketKey() const
329 return d_conn && d_conn->getUnknownTicketKey();
// TLS connection created from the TLSCtx in the constructor; nullptr for plain TCP.
333 std::unique_ptr<TLSConnection> d_conn{nullptr};