#include "lib/layer/iterate.h" /* kr_response_classify */
#include "lib/cache/util.h"
+#include "contrib/cleanup.h"
#include "contrib/base64url.h"
#define MAKE_NV(K, KS, V, VS) \
return 0;
}
+/*
+ * Check endpoint and uri path
+ */
+static int check_uri(const char* uri_path)
+{
+ static const char key[] = "dns=";
+ static const char *delim = "&";
+ static const char *endpoins[] = {"dns-query", "doh"};
+ char *beg;
+ char *end_prev;
+ ssize_t endpoint_len;
+ ssize_t ret;
+
+ if (!uri_path)
+ return kr_error(EINVAL);
+
+ auto_free char *path = malloc(sizeof(*path) * (strlen(uri_path) + 1));
+ if (!path)
+ return kr_error(ENOMEM);
+
+ memcpy(path, uri_path, strlen(uri_path));
+ path[strlen(uri_path)] = '\0';
+
+ char *query_mark = strstr(path, "?");
+
+ /* calculating of endpoint_len - for POST or GET method */
+ endpoint_len = (query_mark) ? query_mark - path - 1 : strlen(path) - 1;
+
+ /* check endpoint */
+ ret = -1;
+ for(int i = 0; i < sizeof(endpoins)/sizeof(*endpoins); i++)
+ {
+ if (strlen(endpoins[i]) != endpoint_len)
+ continue;
+ ret = strncmp(path + 1, endpoins[i], strlen(endpoins[i]));
+ if (!ret)
+ break;
+ }
+
+ if (ret) /* no endpoint found */
+ return -1;
+ if (endpoint_len == strlen(path) - 1) /* done for POST method */
+ return 0;
+
+ /* go over key:value pair */
+ beg = strtok(query_mark + 1, delim);
+ if (beg) {
+ while (beg != NULL) {
+ if (!strncmp(beg, key, 4)) { /* dns variable in path found */
+ break;
+ }
+ end_prev = beg + strlen(beg);
+ beg = strtok(NULL, delim);
+ if (beg-1 != end_prev) { /* detect && */
+ return -1;
+ }
+ }
+
+ if (!beg) { /* no dns variable in path */
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
/*
* Process a query from URI path if there's base64url encoded dns variable.
*/
struct http_ctx *ctx = (struct http_ctx *)user_data;
int32_t stream_id = frame->hd.stream_id;
-
if (frame->hd.type != NGHTTP2_HEADERS)
return 0;
}
if (!strcasecmp(":path", (const char *)name)) {
+ if (check_uri((const char *)value) < 0) {
+ refuse_stream(h2, stream_id);
+ return 0;
+ }
+
ctx->uri_path = malloc(sizeof(*ctx->uri_path) * (valuelen + 1));
if (!ctx->uri_path)
return kr_error(ENOMEM);
local function start_server()
local request = require('http.request')
local ssl_ctx = require('openssl.ssl.context')
- uri_templ = string.format('https://%s:%d/dns_query', host, port)
+ uri_templ = string.format('https://%s:%d/dns-query', host, port)
req_templ = assert(request.new_from_uri(uri_templ))
req_templ.headers:upsert('content-type', 'application/dns-message')
req_templ.ctx = ssl_ctx.new()