]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Use Arc instead of global static
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Mon, 25 Nov 2024 10:13:36 +0000 (11:13 +0100)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 11 Feb 2025 15:28:22 +0000 (16:28 +0100)
pdns/recursordist/settings/rust/src/web.rs
pdns/recursordist/ws-recursor.cc

index d934f21456090ed0dfd04121555a6a595d90effd..06e3559653772b98f24d9967e82f9db7e10cd32c 100644 (file)
@@ -1,3 +1,18 @@
+/*
+TODO
+
+- Logging
+- Table based routing including OPTIONS request handling
+- Requests taking e.g. an <id>
+- ACLs
+- Authorization
+- Allow multipe listen addreses in settings (singlevalued right now)
+- TLS?
+- Code is now in settings dir. It's only possible to split the modules into separate Rust libs if we
+  use shared libs (in theory, I did not try). Currenlty all CXX using Rust cargo's must be compiled
+  as one and refer to a single static Rust runtime,
+*/
+
 use std::net::SocketAddr;
 
 use bytes::Bytes;
@@ -9,9 +24,10 @@ use hyper_util::rt::TokioIo;
 use tokio::net::TcpListener;
 use tokio::runtime::Builder;
 use tokio::task::JoinSet;
-
 use std::io::ErrorKind;
 use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::Mutex;
 
 type GenericError = Box<dyn std::error::Error + Send + Sync>;
 type MyResult<T> = std::result::Result<T, GenericError>;
@@ -28,17 +44,66 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody {
 type Func = fn(&rustweb::Request, &mut rustweb::Response) -> Result<(), cxx::Exception>;
 
 fn api_wrapper(
+    ctx: &Context,
     handler: Func,
     request: &rustweb::Request,
     response: &mut rustweb::Response,
+    reqheaders: &header::HeaderMap,
     headers: &mut header::HeaderMap,
 ) {
-    response.status = StatusCode::OK.as_u16(); // 200;
-                                               // security headers
+    // security headers
     headers.insert(
         header::ACCESS_CONTROL_ALLOW_ORIGIN,
         header::HeaderValue::from_static("*"),
     );
+    if ctx.api_key.is_empty() {
+        // XXX log
+        // Www-Authenticate: X-API-Key realm="PowerDNS"
+        let status =  StatusCode::UNAUTHORIZED;
+        response.status = status.as_u16();
+        headers.insert(
+            header::WWW_AUTHENTICATE,
+            header::HeaderValue::from_static("X-API-Key ream=\"PowerDNS\""),
+        );
+        response.body = status.canonical_reason().unwrap().as_bytes().to_vec();
+        return;
+    }
+
+    // XXX encrypted credentials handling, password handling!
+    let allow_password = false;
+    let mut auth_ok = false;
+    if let Some(api) = reqheaders.get("x-api-key") {
+        auth_ok = api.as_bytes() == ctx.api_key.as_bytes();
+        println!("OK {}", auth_ok);
+    }
+    if !auth_ok {
+        for kv in &request.vars {
+            if kv.key == "x-api-key" && kv.value == ctx.api_key {
+                auth_ok = true;
+                break;
+            }
+        }
+    }
+    if !auth_ok && allow_password {
+        if !ctx.webserver_password.is_empty() {
+            //auth_ok = req->compareAuthorization(*d_webserverPassword); XXX
+        } else {
+            auth_ok = true;
+        }
+    }
+    if !auth_ok {
+        // XXX log
+        let status =  StatusCode::UNAUTHORIZED;
+        response.status = status.as_u16();
+        headers.insert(
+            header::WWW_AUTHENTICATE,
+            header::HeaderValue::from_static("X-API-Key ream=\"PowerDNS\""),
+        );
+        response.body = status.canonical_reason().unwrap().as_bytes().to_vec();
+        return;
+    }
+    response.status = StatusCode::OK.as_u16(); // 200;
+
     headers.insert(
         header::X_CONTENT_TYPE_OPTIONS,
         header::HeaderValue::from_static("nosniff"),
@@ -63,15 +128,28 @@ fn api_wrapper(
     match handler(request, response) {
         Ok(_) => {}
         Err(_) => {
-            response.status = StatusCode::UNPROCESSABLE_ENTITY.as_u16(); // 422
+            let status =  StatusCode::UNPROCESSABLE_ENTITY; // 422
+            response.status = status.as_u16();
+            response.body = status.canonical_reason().unwrap().as_bytes().to_vec();
         }
     }
 }
 
+struct Context {
+    urls: Vec<String>,
+    api_key: String,
+    webserver_password: String,
+    counter: Mutex<u32>,
+}
+
 async fn hello(
     rust_request: Request<IncomingBody>,
-    urls: &[String],
+    ctx: Arc<Context>
 ) -> MyResult<Response<BoxBody>> {
+    {
+        let mut counter = ctx.counter.lock().await;
+        *counter += 1;
+    }
     let mut rust_response = Response::builder();
     let mut vars: Vec<rustweb::KeyValue> = vec![];
     if let Some(query) = rust_request.uri().query() {
@@ -98,173 +176,55 @@ async fn hello(
         headers: vec![],
     };
     let headers = rust_response.headers_mut().expect("no headers?");
-    match (rust_request.method(), rust_request.uri().path()) {
-        (&Method::GET, "/jsonstat") => {
-            api_wrapper(
-                rustweb::jsonstat as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::PUT, "/api/v1/servers/localhost/cache/flush") => {
-            request.body = rust_request.collect().await?.to_bytes().to_vec();
-            api_wrapper(
-                rustweb::apiServerCacheFlush as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::PUT, "/api/v1/servers/localhost/config/allow-from") => {
-            request.body = rust_request.collect().await?.to_bytes().to_vec();
-            api_wrapper(
-                rustweb::apiServerConfigAllowFromPUT as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/config/allow-from") => {
-            api_wrapper(
-                rustweb::apiServerConfigAllowFromGET as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::PUT, "/api/v1/servers/localhost/config/allow-notify-from") => {
-            request.body = rust_request.collect().await?.to_bytes().to_vec();
-            api_wrapper(
-                rustweb::apiServerConfigAllowNotifyFromPUT as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/config/allow-notify-from") => {
-            api_wrapper(
-                rustweb::apiServerConfigAllowNotifyFromGET as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/config") => {
-            api_wrapper(
-                rustweb::apiServerConfig as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/rpzstatistics") => {
-            api_wrapper(
-                rustweb::apiServerRPZStats as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/search-data") => {
-            api_wrapper(
-                rustweb::apiServerSearchData as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/zones/") => {
-            api_wrapper(
-                rustweb::apiServerZoneDetailGET as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::PUT, "/api/v1/servers/localhost/zones/") => {
-            request.body = rust_request.collect().await?.to_bytes().to_vec();
-            api_wrapper(
-                rustweb::apiServerZoneDetailPUT as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::DELETE, "/api/v1/servers/localhost/zones/") => {
-            api_wrapper(
-                rustweb::apiServerZoneDetailDELETE as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/statistics") => {
-            api_wrapper(
-                rustweb::apiServerStatistics as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost/zones") => {
-            api_wrapper(
-                rustweb::apiServerZonesGET as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::POST, "/api/v1/servers/localhost/zones") => {
-            request.body = rust_request.collect().await?.to_bytes().to_vec();
-            api_wrapper(
-                rustweb::apiServerZonesPOST as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers/localhost") => {
-            api_wrapper(
-                rustweb::apiServerDetail as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1/servers") => {
-            api_wrapper(
-                rustweb::apiServer as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api/v1") => {
-            api_wrapper(
-                rustweb::apiDiscoveryV1 as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/api") => {
-            api_wrapper(
-                rustweb::apiDiscovery as Func,
-                &request,
-                &mut response,
-                headers,
-            );
-        }
-        (&Method::GET, "/metrics") => {
-            rustweb::prometheusMetrics(&request, &mut response).unwrap();
-        }
+    let mut apifunc: Option<Func> = None;
+    let method = rust_request.method().to_owned();
+    match (&method, rust_request.uri().path()) {
+        (&Method::GET, "/jsonstat") =>
+            apifunc = Some(rustweb::jsonstat),
+        (&Method::PUT, "/api/v1/servers/localhost/cache/flush") =>
+            apifunc = Some(rustweb::apiServerCacheFlush),
+        (&Method::PUT, "/api/v1/servers/localhost/config/allow-from") =>
+            apifunc = Some(rustweb::apiServerConfigAllowFromPUT),
+        (&Method::GET, "/api/v1/servers/localhost/config/allow-from") =>
+            apifunc = Some(rustweb::apiServerConfigAllowFromGET),
+        (&Method::PUT, "/api/v1/servers/localhost/config/allow-notify-from") =>
+            apifunc = Some(rustweb::apiServerConfigAllowNotifyFromPUT),
+        (&Method::GET, "/api/v1/servers/localhost/config/allow-notify-from") =>
+            apifunc = Some(rustweb::apiServerConfigAllowNotifyFromGET),
+        (&Method::GET, "/api/v1/servers/localhost/config") =>
+            apifunc = Some(rustweb::apiServerConfig),
+        (&Method::GET, "/api/v1/servers/localhost/rpzstatistics") =>
+            apifunc = Some(rustweb::apiServerRPZStats),
+        (&Method::GET, "/api/v1/servers/localhost/search-data") =>
+            apifunc = Some(rustweb::apiServerSearchData),
+        (&Method::GET, "/api/v1/servers/localhost/zones/") =>
+            apifunc = Some(rustweb::apiServerZoneDetailGET),
+        (&Method::PUT, "/api/v1/servers/localhost/zones/") =>
+            apifunc = Some(rustweb::apiServerZoneDetailPUT),
+        (&Method::DELETE, "/api/v1/servers/localhost/zones/") =>
+            apifunc = Some(rustweb::apiServerZoneDetailDELETE),
+        (&Method::GET, "/api/v1/servers/localhost/statistics") =>
+            apifunc = Some(rustweb::apiServerStatistics),
+        (&Method::GET, "/api/v1/servers/localhost/zones") =>
+            apifunc = Some(rustweb::apiServerZonesGET),
+        (&Method::POST, "/api/v1/servers/localhost/zones") =>
+            apifunc = Some(rustweb::apiServerZonesPOST),
+        (&Method::GET, "/api/v1/servers/localhost") =>
+            apifunc = Some(rustweb::apiServerDetail),
+        (&Method::GET, "/api/v1/servers") =>
+            apifunc = Some(rustweb::apiServer),
+        (&Method::GET, "/api/v1") =>
+            apifunc = Some(rustweb::apiDiscoveryV1),
+        (&Method::GET, "/api") =>
+            apifunc = Some(rustweb::apiDiscovery),
+        (&Method::GET, "/metrics") =>
+            rustweb::prometheusMetrics(&request, &mut response).unwrap(),
         _ => {
             let mut path = rust_request.uri().path();
             if path == "/" {
                 path = "/index.html";
             }
-            let pos = urls.iter().position(|x| String::from("/") + x == path);
+            let pos = ctx.urls.iter().position(|x| String::from("/") + x == path);
             if pos.is_none() {
                 eprintln!("{} {} not found", rust_request.method(), path);
             }
@@ -276,52 +236,81 @@ async fn hello(
             }
         }
     }
+    if let Some(func) = apifunc {
+        let reqheaders = rust_request.headers().clone();
+        if rust_request.method()== Method::POST || rust_request.method() == Method::PUT {
+            request.body = rust_request.collect().await?.to_bytes().to_vec();
+        }
+        api_wrapper(
+            &ctx,
+            func,
+            &request,
+            &mut response,
+            &reqheaders,
+            headers,
+        );
+    }
+
+    let mut body = full(response.body);
+    if method == Method::HEAD {
+        body = full(vec!());
+    }
+
     let mut rust_response = rust_response
         .status(StatusCode::from_u16(response.status).unwrap())
-        .body(full(response.body))?;
+        .body(body)?;
     for kv in response.headers {
         rust_response.headers_mut().insert(
             header::HeaderName::from_bytes(kv.key.as_bytes()).unwrap(),
             header::HeaderValue::from_str(kv.value.as_str()).unwrap(),
         );
     }
+
+    rust_response.headers_mut().insert(
+        header::CONNECTION,
+        header::HeaderValue::from_str("close").unwrap(),
+    );
     Ok(rust_response)
 }
 
-async fn serveweb_async(listener: TcpListener, urls: &'static [String]) -> MyResult<()> {
-    //let request_counter = Arc::new(AtomicUsize::new(0));
-    /*
-        let fut = http1::Builder::new()
-            .serve_connection(move || {
-                service_fn(move |req| hello(req))
-    });
-        */
+async fn serveweb_async(listener: TcpListener, ctx: Arc<Context>) -> MyResult<()> {
+
     // We start a loop to continuously accept incoming connections
     loop {
+        let ctx = Arc::clone(&ctx);
+        let ctx2 = Arc::clone(&ctx);
         let (stream, _) = listener.accept().await?;
 
         // Use an adapter to access something implementing `tokio::io` traits as if they implement
         // `hyper::rt` IO traits.
         let io = TokioIo::new(stream);
         let fut =
-            http1::Builder::new().serve_connection(io, service_fn(move |req| hello(req, urls)));
+            http1::Builder::new().serve_connection(io, service_fn(move |req| {
+                let ctx = Arc::clone(&ctx);
+                hello(req, ctx)
+            }));
 
         // Spawn a tokio task to serve multiple connections concurrently
         tokio::task::spawn(async move {
             // Finally, we bind the incoming connection to our `hello` service
-            if let Err(err) = /* http1::Builder::new()
-                // `service_fn` converts our function in a `Service`
-                    .serve_connection(io, service_fn(|req| hello(req)))
-                    */
-                fut.await
+            if let Err(err) = fut.await
             {
                 eprintln!("Error serving connection: {:?}", err);
             }
         });
+        eprintln!("{}", ctx2.counter.lock().await);
     }
 }
 
-pub fn serveweb(addresses: &Vec<String>, urls: &'static [String]) -> Result<(), std::io::Error> {
+pub fn serveweb(addresses: &Vec<String>, urls: &[String], api_key: String, webserver_password: String) -> Result<(), std::io::Error> {
+    // Context (R/O for now)
+    let ctx = Arc::new(Context {
+        urls: urls.to_vec(),
+        api_key,
+        webserver_password,
+        counter: Mutex::new(0),
+    });
+
     let runtime = Builder::new_current_thread()
         .worker_threads(1)
         .thread_name("rec/web")
@@ -342,11 +331,11 @@ pub fn serveweb(addresses: &Vec<String>, urls: &'static [String]) -> Result<(),
         };
 
         let listener = runtime.block_on(async { TcpListener::bind(addr).await });
-
+        let ctx = Arc::clone(&ctx);
         match listener {
             Ok(val) => {
                 println!("Listening on {}", addr);
-                set.spawn_on(serveweb_async(val, urls), runtime.handle());
+                set.spawn_on(serveweb_async(val, ctx), runtime.handle());
             }
             Err(err) => {
                 let msg = format!("Unable to bind web socket: {}", err);
@@ -367,13 +356,13 @@ pub fn serveweb(addresses: &Vec<String>, urls: &'static [String]) -> Result<(),
 }
 
 #[cxx::bridge(namespace = "pdns::rust::web::rec")]
-/*
- * Functions callable from C++
- */
 mod rustweb {
 
+    /*
+     * Functions callable from C++
+     */
     extern "Rust" {
-        fn serveweb(addreses: &Vec<String>, urls: &'static [String]) -> Result<()>;
+        fn serveweb(addreses: &Vec<String>, urls: &[String], apikey: String, password: String) -> Result<()>;
     }
 
     struct KeyValue {
@@ -393,6 +382,9 @@ mod rustweb {
         headers: Vec<KeyValue>,
     }
 
+    /*
+     * Functions callable from Rust
+     */
     unsafe extern "C++" {
         include!("bridge.hh");
         fn apiDiscovery(request: &Request, response: &mut Response) -> Result<()>;
index 1e112647efc5344d8c546ca83c83cd371e7c28fb..d6d8b681ac678b3039ad76c7ee5e71458b00d598 100644 (file)
@@ -956,12 +956,12 @@ void AsyncWebServer::go()
 
 void serveRustWeb()
 {
-  static ::rust::Vec<::rust::String> urls;
+  ::rust::Vec<::rust::String> urls;
   for (const auto& [url, _] : g_urlmap) {
     urls.emplace_back(url);
   }
   auto address = ComboAddress(arg()["webserver-address"], arg().asNum("webserver-port"));
-  pdns::rust::web::rec::serveweb({::rust::String(address.toStringWithPort())}, ::rust::Slice<const ::rust::String>{urls.data(), urls.size()});
+  pdns::rust::web::rec::serveweb({::rust::String(address.toStringWithPort())}, ::rust::Slice<const ::rust::String>{urls.data(), urls.size()}, arg()["api-key"], arg()["webserver-password"]);
 }
 
 static void fromCxxToRust(const HttpResponse& cxxresp, pdns::rust::web::rec::Response& rustResponse)