/* called from the backend code when a new response has been received */
void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
{
+ if (std::this_thread::get_id() != d_mainThreadID) {
+ handleCrossProtocolResponse(now, std::move(response));
+ return;
+ }
+
std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
struct timeval d_now;
};
-class TCPCrossProtocolQuerySender : public TCPQuerySender
-{
-public:
- TCPCrossProtocolQuerySender(std::shared_ptr<IncomingTCPConnectionState>& state): d_state(state)
- {
- }
-
- bool active() const override
- {
- return d_state->active();
- }
-
- const ClientState* getClientState() const override
- {
- return d_state->getClientState();
- }
-
- void handleResponse(const struct timeval& now, TCPResponse&& response) override
- {
- if (d_state->d_threadData.crossProtocolResponsesPipe == -1) {
- throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
- }
-
- auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now);
- static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
- ssize_t sent = write(d_state->d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
- if (sent != sizeof(ptr)) {
- if (errno == EAGAIN || errno == EWOULDBLOCK) {
- ++g_stats.tcpCrossProtocolResponsePipeFull;
- vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
- }
- else {
- vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
- }
- delete ptr;
- }
- }
-
- void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
- {
- handleResponse(now, std::move(response));
- }
-
- void notifyIOError(IDState&& query, const struct timeval& now) override
- {
- TCPResponse response(PacketBuffer(), std::move(query), nullptr);
- handleResponse(now, std::move(response));
- }
-
-private:
- std::shared_ptr<IncomingTCPConnectionState> d_state;
-};
-
class TCPCrossProtocolQuery : public CrossProtocolQuery
{
public:
- TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<TCPCrossProtocolQuerySender>& sender): d_sender(sender)
+ TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
{
- query = InternalQuery(std::move(buffer), std::move(ids));
- downstream = ds;
proxyProtocolPayloadSize = 0;
}
}
private:
- std::shared_ptr<TCPCrossProtocolQuerySender> d_sender;
+ std::shared_ptr<IncomingTCPConnectionState> d_sender;
};
+void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
+{
+ if (d_threadData.crossProtocolResponsesPipe == -1) {
+ throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
+ }
+
+ std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
+ auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now);
+ static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+ ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
+ if (sent != sizeof(ptr)) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ ++g_stats.tcpCrossProtocolResponsePipeFull;
+ vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
+ }
+ else {
+ vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
+ }
+ delete ptr;
+ }
+}
+
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
if (state->d_querySize < sizeof(dnsheader)) {
proxyProtocolPayload = getProxyProtocolPayload(dq);
}
- auto incoming = std::make_shared<TCPCrossProtocolQuerySender>(state);
- auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, incoming);
+ auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
ds->passCrossProtocolQuery(std::move(cpq));
class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
{
public:
- IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData)
+ IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_mainThreadID(std::this_thread::get_id())
{
d_origDest.reset();
d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
void notifyIOError(IDState&& query, const struct timeval& now) override;
+ void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
+
void terminateClientConnection();
void queueQuery(TCPQuery&& query);
size_t d_proxyProtocolNeed{0};
size_t d_queriesCount{0};
size_t d_currentQueriesCount{0};
+ std::thread::id d_mainThreadID;
uint16_t d_querySize{0};
State d_state{State::doingHandshake};
bool d_isXFR{false};