From 3c3c006b00b809efb9dd7c70a4cedf599c3da6e4 Mon Sep 17 00:00:00 2001 From: Christian Hofstaedtler Date: Wed, 5 Feb 2014 16:06:44 +0100 Subject: [PATCH] Encode symbols in zone names for REST API URLs --- pdns/ws-api.cc | 59 ++++++++++++++++++++++++++++++ pdns/ws-api.hh | 4 ++ pdns/ws-auth.cc | 21 ++++++++--- pdns/ws-auth.hh | 8 ---- regression-tests.api/test_Zones.py | 13 ++++++- 5 files changed, 89 insertions(+), 16 deletions(-) diff --git a/pdns/ws-api.cc b/pdns/ws-api.cc index b0e097b3c2..4a25319c9a 100644 --- a/pdns/ws-api.cc +++ b/pdns/ws-api.cc @@ -31,6 +31,7 @@ #include #include #include +#include #ifndef HAVE_STRCASESTR @@ -222,3 +223,61 @@ void apiServerStatistics(HttpRequest* req, HttpResponse* resp) { resp->setBody(doc); } + +string apiZoneIdToName(const string& id) { + string zonename; + ostringstream ss; + + if(id.empty()) + throw HttpBadRequestException(); + + std::size_t lastpos = 0, pos = 0; + while ((pos = id.find('=', lastpos)) != string::npos) { + ss << id.substr(lastpos, pos-lastpos); + if ((id[pos+1] >= '0' && id[pos+1] <= '9') && + (id[pos+2] >= '0' && id[pos+2] <= '9')) { + char c = ((id[pos+1] - '0')*10) + (id[pos+2] - '0'); + ss << c; + } else { + throw HttpBadRequestException(); + } + + lastpos = pos+3; + } + if (lastpos < pos) { + ss << id.substr(lastpos, pos-lastpos); + } + + zonename = ss.str(); + + // strip trailing dot + if (zonename.substr(zonename.size()-1) == ".") { + zonename = zonename.substr(0, zonename.size()-1); + } + return zonename; +} + +string apiZoneNameToId(const string& name) { + ostringstream ss; + + for(string::const_iterator iter = name.begin(); iter != name.end(); ++iter) { + if ((*iter >= 'A' && *iter <= 'Z') || + (*iter >= 'a' && *iter <= 'z') || + (*iter >= '0' && *iter <= '9') || + (*iter == '.') || (*iter == '-')) { + ss << *iter; + } else { + ss << "=" << std::setfill('0') << std::setw(2) << (int)(*iter); + } + } + + // add trailing dot + string id = ss.str() + "."; + + // special handling for the root zone, as a dot on it's own doesn't work + // everywhere. + if (id == ".") { + id = (boost::format("=%d") % (int)('.')).str(); + } + return id; +} diff --git a/pdns/ws-api.hh b/pdns/ws-api.hh index 3f28212e67..76cfc522fe 100644 --- a/pdns/ws-api.hh +++ b/pdns/ws-api.hh @@ -31,5 +31,9 @@ void apiServerConfig(HttpRequest* req, HttpResponse* resp); void apiServerSearchLog(HttpRequest* req, HttpResponse* resp); void apiServerStatistics(HttpRequest* req, HttpResponse* resp); +// helpers +string apiZoneIdToName(const string& id); +string apiZoneNameToId(const string& name); + // To be provided by product code. void productServerStatisticsFetch(std::map& out); diff --git a/pdns/ws-auth.cc b/pdns/ws-auth.cc index 6935601533..f80b539033 100644 --- a/pdns/ws-auth.cc +++ b/pdns/ws-auth.cc @@ -38,6 +38,11 @@ #include "rapidjson/writer.h" #include "ws-api.hh" #include "version.hh" +#include + +#ifdef HAVE_CONFIG_H +# include +#endif // HAVE_CONFIG_H using namespace rapidjson; @@ -281,8 +286,10 @@ static void fillZone(const string& zonename, HttpResponse* resp) { doc.SetObject(); // id is the canonical lookup key, which doesn't actually match the name (in some cases) - doc.AddMember("id", di.zone.c_str(), doc.GetAllocator()); - string url = (boost::format("/servers/localhost/zones/%s") % di.zone).str(); + string zoneId = apiZoneNameToId(di.zone); + Value jzoneId(zoneId.c_str(), doc.GetAllocator()); // copy + doc.AddMember("id", jzoneId, doc.GetAllocator()); + string url = "/servers/localhost/zones/" + zoneId; Value jurl(url.c_str(), doc.GetAllocator()); // copy doc.AddMember("url", jurl, doc.GetAllocator()); doc.AddMember("name", di.zone.c_str(), doc.GetAllocator()); @@ -422,8 +429,10 @@ static void apiServerZones(HttpRequest* req, HttpResponse* resp) { Value jdi; jdi.SetObject(); // id is the canonical lookup key, which doesn't actually match the name (in some cases) - jdi.AddMember("id", di.zone.c_str(), doc.GetAllocator()); - string url = (boost::format("/servers/localhost/zones/%s") % di.zone).str(); + string zoneId = apiZoneNameToId(di.zone); + Value jzoneId(zoneId.c_str(), doc.GetAllocator()); // copy + jdi.AddMember("id", jzoneId, doc.GetAllocator()); + string url = "/servers/localhost/zones/" + zoneId; Value jurl(url.c_str(), doc.GetAllocator()); // copy jdi.AddMember("url", jurl, doc.GetAllocator()); jdi.AddMember("name", di.zone.c_str(), doc.GetAllocator()); @@ -444,7 +453,7 @@ static void apiServerZones(HttpRequest* req, HttpResponse* resp) { } static void apiServerZoneDetail(HttpRequest* req, HttpResponse* resp) { - string zonename = req->path_parameters["id"]; + string zonename = apiZoneIdToName(req->path_parameters["id"]); if(req->method == "PUT") { // update domain settings @@ -497,7 +506,7 @@ static void apiServerZoneRRset(HttpRequest* req, HttpResponse* resp) { UeberBackend B; DomainInfo di; - string zonename = req->path_parameters["id"]; + string zonename = apiZoneIdToName(req->path_parameters["id"]); if(!B.getDomainInfo(zonename, di)) throw ApiException("Could not find domain '"+zonename+"'"); diff --git a/pdns/ws-auth.hh b/pdns/ws-auth.hh index cf21275dbc..243623b95a 100644 --- a/pdns/ws-auth.hh +++ b/pdns/ws-auth.hh @@ -25,14 +25,6 @@ #include #include #include -#include -#include -#include - -#ifdef HAVE_CONFIG_H -# include -#endif // HAVE_CONFIG_H - #include "misc.hh" #include "namespaces.hh" diff --git a/regression-tests.api/test_Zones.py b/regression-tests.api/test_Zones.py index aa5a2a8b73..b4a5479fd8 100644 --- a/regression-tests.api/test_Zones.py +++ b/regression-tests.api/test_Zones.py @@ -41,17 +41,26 @@ class Servers(ApiTestCase): if k in payload: self.assertEquals(data[k], payload[k]) - @unittest.expectedFailure def test_CreateZoneWithSymbols(self): payload, data = self.create_zone(name='foo/bar.'+unique_zone_name()) name = payload['name'] - expected_id = name.replace('/', '\\047') + expected_id = (name.replace('/', '=47')) + '.' for k in ('id', 'url', 'name', 'masters', 'kind', 'last_check', 'notified_serial', 'serial'): self.assertIn(k, data) if k in payload: self.assertEquals(data[k], payload[k]) self.assertEquals(data['id'], expected_id) + def test_GetZoneWithSymbols(self): + payload, data = self.create_zone(name='foo/bar.'+unique_zone_name()) + name = payload['name'] + zone_id = (name.replace('/', '=47')) + '.' + r = self.session.get(self.url("/servers/localhost/zones/" + zone_id)) + for k in ('id', 'url', 'name', 'masters', 'kind', 'last_check', 'notified_serial', 'serial'): + self.assertIn(k, data) + if k in payload: + self.assertEquals(data[k], payload[k]) + def test_GetZone(self): r = self.session.get(self.url("/servers/localhost/zones")) domains = r.json() -- 2.47.2