]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Lua bindings to access the HTTP path and headers
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 2 Aug 2019 09:35:19 +0000 (11:35 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 5 Aug 2019 15:31:09 +0000 (17:31 +0200)
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/doh.cc
pdns/doh.hh

index 207bbe688ca9dc2aa6128800044e27116c49c99c..40c5412bb8c7ef80c809015bb8653ea62ad6322d 100644 (file)
@@ -1091,7 +1091,7 @@ public:
       return Action::None;
     }
 
-    DOHSetHTTPResponse(*dq->du, d_code, d_reason, d_body);
+    dq->du->setHTTPResponse(d_code, d_reason, d_body);
     dq->dh->qr = true; // for good measure
     return Action::HeaderModify;
   }
index 1392397e89c71874d3eb0ec67172a07182d1ab88..50685e6d269c98c72fb3c39bd6a9cba5e734b68f 100644 (file)
@@ -176,4 +176,48 @@ void setupLuaBindingsDNSQuestion()
       }
 #endif /* HAVE_NET_SNMP */
     });
+
+#ifdef HAVE_DNS_OVER_HTTPS
+    g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getHTTPPath", [](const DNSQuestion& dq) {
+      if (dq.du == nullptr) {
+        return std::string();
+      }
+      return dq.du->getHTTPPath();
+    });
+
+    g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getHTTPQueryString", [](const DNSQuestion& dq) {
+      if (dq.du == nullptr) {
+        return std::string();
+      }
+      return dq.du->getHTTPQueryString();
+    });
+
+    g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getHTTPHost", [](const DNSQuestion& dq) {
+      if (dq.du == nullptr) {
+        return std::string();
+      }
+      return dq.du->getHTTPHost();
+    });
+
+    g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getHTTPScheme", [](const DNSQuestion& dq) {
+      if (dq.du == nullptr) {
+        return std::string();
+      }
+      return dq.du->getHTTPScheme();
+    });
+
+    g_lua.registerFunction<std::unordered_map<std::string, std::string>(DNSQuestion::*)(void)>("getHTTPHeaders", [](const DNSQuestion& dq) {
+      if (dq.du == nullptr) {
+        return std::unordered_map<std::string, std::string>();
+      }
+      return dq.du->getHTTPHeaders();
+    });
+
+    g_lua.registerFunction<void(DNSQuestion::*)(uint16_t statusCode, std::string reason, std::string body)>("setHTTPResponse", [](DNSQuestion& dq, uint16_t statusCode, std::string reason, std::string body) {
+      if (dq.du == nullptr) {
+        return;
+      }
+      dq.du->setHTTPResponse(statusCode, reason, body);
+    });
+#endif /* HAVE_DNS_OVER_HTTPS */
 }
index 0ee6db985a8c25eda9e9ca127a4a3ce9762bef05..08fa6c6ecaa98bb9e583edd6d52a917d0962b24d 100644 (file)
@@ -217,7 +217,9 @@ static int processDOHQuery(DOHUnit* du)
     }
 
     if (result == ProcessQueryResult::SendAnswer) {
-      du->response = std::string(reinterpret_cast<char*>(dq.dh), dq.len);
+      if (du->response.empty()) {
+        du->response = std::string(reinterpret_cast<char*>(dq.dh), dq.len);
+      }
       send(du->rsock, &du, sizeof(du), 0);
       return 0;
     }
@@ -544,11 +546,11 @@ HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& reg
 
 bool HTTPHeaderRule::matches(const DNSQuestion* dq) const
 {
-  if(!dq->du) {
+  if (!dq->du) {
     return false;
   }
 
-  for (unsigned int i = 0; i != dq->du->req->headers.size; ++i) {
+  for (size_t i = 0; i < dq->du->req->headers.size; ++i) {
     if(std::string(dq->du->req->headers.entries[i].name->base, dq->du->req->headers.entries[i].name->len) == d_header &&
        d_regex.match(std::string(dq->du->req->headers.entries[i].value.base, dq->du->req->headers.entries[i].value.len))) {
       return true;
@@ -593,29 +595,70 @@ HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex): d_regex(regex),
 
 bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const
 {
-  if(!dq->du) {
+  if (!dq->du) {
     return false;
   }
 
-  if(dq->du->req->query_at == SIZE_MAX) {
-    return d_regex.match(std::string(dq->du->req->path.base, dq->du->req->path.len));
+  return d_regex.match(dq->du->getHTTPPath());
+}
+
+string HTTPPathRegexRule::toString() const
+{
+  return d_visual;
+}
+
+std::unordered_map<std::string, std::string> DOHUnit::getHTTPHeaders() const
+{
+  std::unordered_map<std::string, std::string> results;
+  results.reserve(req->headers.size);
+
+  for (size_t i = 0; i < req->headers.size; ++i) {
+    results.insert({std::string(req->headers.entries[i].name->base, req->headers.entries[i].name->len),
+                    std::string(req->headers.entries[i].value.base, req->headers.entries[i].value.len)});
+  }
+
+  return results;
+}
+
+std::string DOHUnit::getHTTPPath() const
+{
+  if (req->query_at == SIZE_MAX) {
+    return std::string(req->path.base, req->path.len);
   }
   else {
-    cerr<<std::string(dq->du->req->path.base, dq->du->req->path.len - dq->du->req->query_at)<<endl;
-    return d_regex.match(std::string(dq->du->req->path.base, dq->du->req->path.len - dq->du->req->query_at));
+    return std::string(req->path.base, req->query_at);
   }
 }
 
-string HTTPPathRegexRule::toString() const
+std::string DOHUnit::getHTTPHost() const
 {
-  return d_visual;
+  return std::string(req->authority.base, req->authority.len);
 }
 
-void DOHSetHTTPResponse(DOHUnit& du, uint16_t statusCode, const std::string& reason, const std::string& body)
+std::string DOHUnit::getHTTPScheme() const
 {
-  du.status_code = statusCode;
-  du.reason = reason;
-  du.body = body;
+  if (req->scheme == nullptr) {
+    return std::string();
+  }
+
+  return std::string(req->scheme->name.base, req->scheme->name.len);
+}
+
+std::string DOHUnit::getHTTPQueryString() const
+{
+  if (req->query_at == SIZE_MAX) {
+    return std::string();
+  }
+  else {
+    return std::string(req->path.base + req->query_at, req->path.len - req->query_at);
+  }
+}
+
+void DOHUnit::setHTTPResponse(uint16_t statusCode, const std::string& reason_, const std::string& body_)
+{
+  status_code = statusCode;
+  reason = reason_;
+  response = body_;
 }
 
 void dnsdistclient(int qsock, int rsock)
@@ -693,34 +736,33 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
     du->req->res.status = 200;
     du->req->res.reason = "OK";
 
-    h2o_add_header(&du->req->pool, &du->req->res.headers, H2O_TOKEN_CONTENT_TYPE, nullptr, H2O_STRLIT("application/dns-message"));
-
     //    struct dnsheader* dh = (struct dnsheader*)du->query.c_str();
     //    cout<<"Attempt to send out "<<du->query.size()<<" bytes over https, TC="<<dh->tc<<", RCODE="<<dh->rcode<<", qtype="<<du->qtype<<", req="<<(void*)du->req<<endl;
 
+    h2o_add_header(&du->req->pool, &du->req->res.headers, H2O_TOKEN_CONTENT_TYPE, nullptr, H2O_STRLIT("application/dns-message"));
     du->req->res.content_length = du->response.size();
     h2o_send_inline(du->req, du->response.c_str(), du->response.size());
   }
   else if (du->status_code >= 300 && du->status_code < 400) {
-    /* in that case the body is actually a URL */
-    h2o_send_redirect(du->req, du->status_code, du->reason.c_str(), du->body.c_str(), du->body.size());
+    /* in that case the response is actually a URL */
+    h2o_send_redirect(du->req, du->status_code, du->reason.c_str(), du->response.c_str(), du->response.size());
     ++dsc->df->d_redirectresponses;
   }
   else {
     switch(du->status_code) {
     case 400:
-      h2o_send_error_400(du->req, du->reason.empty() ? "Bad Request" : du->reason.c_str(), du->body.empty() ? "invalid DNS query" : du->body.c_str(), 0);
+      h2o_send_error_400(du->req, du->reason.empty() ? "Bad Request" : du->reason.c_str(), du->response.empty() ? "invalid DNS query" : du->response.c_str(), 0);
       break;
     case 403:
-      h2o_send_error_403(du->req, du->reason.empty() ? "Forbidden" : du->reason.c_str(), du->body.empty() ? "dns query not allowed" : du->body.c_str(), 0);
+      h2o_send_error_403(du->req, du->reason.empty() ? "Forbidden" : du->reason.c_str(), du->response.empty() ? "dns query not allowed" : du->response.c_str(), 0);
       break;
     case 502:
-      h2o_send_error_502(du->req, du->reason.empty() ? "Bad Gateway" : du->reason.c_str(), du->body.empty() ? "no downstream server available" : du->body.c_str(), 0);
+      h2o_send_error_502(du->req, du->reason.empty() ? "Bad Gateway" : du->reason.c_str(), du->response.empty() ? "no downstream server available" : du->response.c_str(), 0);
       break;
     case 500:
       /* fall-through */
     default:
-      h2o_send_error_500(du->req, du->reason.empty() ? "Internal Server Error" : du->reason.c_str(), du->body.empty() ? "Internal Server Error" : du->body.c_str(), 0);
+      h2o_send_error_500(du->req, du->reason.empty() ? "Internal Server Error" : du->reason.c_str(), du->response.empty() ? "Internal Server Error" : du->response.c_str(), 0);
       break;
     }
 
index cc16ccfb5abf15ec3088d4eca2a0740112f5bec8..c1739d987fac44c7ad9942c7c403c0a1b8e6d396 100644 (file)
@@ -64,6 +64,8 @@ struct DOHUnit
 };
 
 #else /* HAVE_DNS_OVER_HTTPS */
+#include <unordered_map>
+
 struct st_h2o_req_t;
 
 struct DOHUnit
@@ -75,7 +77,6 @@ struct DOHUnit
   st_h2o_req_t* req{nullptr};
   DOHUnit** self{nullptr};
   std::string reason;
-  std::string body;
   int rsock;
   uint16_t qtype;
   /* the status_code is set from
@@ -87,9 +88,14 @@ struct DOHUnit
   */
   uint16_t status_code{200};
   bool ednsAdded{false};
-};
 
-void DOHSetHTTPResponse(DOHUnit& du, uint16_t statusCode, const std::string& reason, const std::string& body);
+  std::string getHTTPPath() const;
+  std::string getHTTPHost() const;
+  std::string getHTTPScheme() const;
+  std::string getHTTPQueryString() const;
+  std::unordered_map<std::string, std::string> getHTTPHeaders() const;
+  void setHTTPResponse(uint16_t statusCode, const std::string& reason, const std::string& body);
+};
 
 #endif /* HAVE_DNS_OVER_HTTPS  */