]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
Add a daemon caching layer that wraps tdb
authorDavid Mulder <dmulder@samba.org>
Tue, 30 Jul 2024 18:56:41 +0000 (12:56 -0600)
committerDavid Mulder <dmulder@samba.org>
Wed, 23 Oct 2024 14:21:33 +0000 (14:21 +0000)
Signed-off-by: David Mulder <dmulder@samba.org>
Reviewed-by: Alexander Bokovoy <ab@samba.org>
himmelblaud/Cargo.toml
himmelblaud/src/cache.rs [new file with mode: 0644]

index 4206e498c0f687f69fd7450e0075054dcdcfd8ae..bce2de8fe08c12990b758c1ed1c94224c2181efe 100644 (file)
@@ -12,7 +12,9 @@ homepage.workspace = true
 version.workspace = true
 
 [dependencies]
+tdb = { workspace = true }
 dbg = { workspace = true }
+libc = "0.2.155"
 
 [workspace]
 members = [
diff --git a/himmelblaud/src/cache.rs b/himmelblaud/src/cache.rs
new file mode 100644 (file)
index 0000000..5104704
--- /dev/null
@@ -0,0 +1,536 @@
+/*
+   Unix SMB/CIFS implementation.
+
+   Himmelblau daemon cache
+
+   Copyright (C) David Mulder 2024
+
+   This program is free software; you can redistribute it and/or modify
+   it under the terms of the GNU General Public License as published by
+   the Free Software Foundation; either version 3 of the License, or
+   (at your option) any later version.
+
+   This program is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+   GNU General Public License for more details.
+
+   You should have received a copy of the GNU General Public License
+   along with this program.  If not, see <http://www.gnu.org/licenses/>.
+*/
+use dbg::DBG_ERR;
+use himmelblau::error::MsalError;
+use himmelblau::graph::DirectoryObject;
+use himmelblau::UserToken;
+use kanidm_hsm_crypto::{
+    AuthValue, BoxedDynTpm, LoadableIdentityKey, LoadableMachineKey,
+    LoadableMsOapxbcRsaKey, Tpm,
+};
+use libc::uid_t;
+use ntstatus_gen::*;
+use serde::{Deserialize, Serialize};
+use serde_json::{from_slice as json_from_slice, to_vec as json_to_vec};
+use std::collections::HashSet;
+use tdb::Tdb;
+
+struct BasicCache {
+    tdb: Tdb,
+}
+
+impl BasicCache {
+    fn new(cache_path: &str) -> Result<Self, Box<NTSTATUS>> {
+        let tdb =
+            Tdb::open(cache_path, None, None, None, None).map_err(|e| {
+                DBG_ERR!("{:?}", e);
+                Box::new(NT_STATUS_FILE_INVALID)
+            })?;
+        Ok(BasicCache { tdb })
+    }
+
+    fn fetch_str(&self, key: &str) -> Option<String> {
+        let exists = match self.tdb.exists(key) {
+            Ok(exists) => exists,
+            Err(e) => {
+                DBG_ERR!("Failed to fetch {}: {:?}", key, e);
+                false
+            }
+        };
+        if exists {
+            match self.tdb.fetch(key) {
+                Ok(val) => Some(val),
+                Err(e) => {
+                    DBG_ERR!("Failed to fetch {}: {:?}", key, e);
+                    None
+                }
+            }
+        } else {
+            None
+        }
+    }
+
+    fn fetch<'a, T>(&self, key: &str) -> Option<T>
+    where
+        T: for<'de> Deserialize<'de>,
+    {
+        match self.fetch_str(key) {
+            Some(val) => match json_from_slice::<T>(val.as_bytes()) {
+                Ok(res) => Some(res),
+                Err(e) => {
+                    DBG_ERR!("Failed to decode {}: {:?}", key, e);
+                    None
+                }
+            },
+            None => {
+                return None;
+            }
+        }
+    }
+
+    fn store_bytes(
+        &mut self,
+        key: &str,
+        val: &[u8],
+    ) -> Result<(), Box<NTSTATUS>> {
+        match self.tdb.transaction_start() {
+            Ok(start) => {
+                if !start {
+                    DBG_ERR!("Failed to start the database transaction.");
+                    return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                }
+            }
+            Err(e) => {
+                DBG_ERR!("Failed to start the database transaction: {:?}", e);
+                return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+            }
+        };
+
+        let res = match self.tdb.store(key, val, None) {
+            Ok(res) => Some(res),
+            Err(e) => {
+                DBG_ERR!("Unable to persist {}: {:?}", key, e);
+                None
+            }
+        };
+
+        let res = match res {
+            Some(res) => res,
+            None => {
+                let _ = self.tdb.transaction_cancel();
+                return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+            }
+        };
+        if !res {
+            DBG_ERR!("Unable to persist {}", key);
+            let _ = self.tdb.transaction_cancel();
+            return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+        }
+
+        let success = match self.tdb.transaction_commit() {
+            Ok(success) => success,
+            Err(e) => {
+                DBG_ERR!("Failed to commit the database transaction: {:?}", e);
+                return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+            }
+        };
+        if !success {
+            DBG_ERR!("Failed to commit the database transaction.");
+            let _ = self.tdb.transaction_cancel();
+            return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+        }
+
+        Ok(())
+    }
+
+    fn store<T>(&mut self, key: &str, val: T) -> Result<(), Box<NTSTATUS>>
+    where
+        T: Serialize,
+    {
+        let val_bytes = match json_to_vec(&val) {
+            Ok(val_bytes) => val_bytes,
+            Err(e) => {
+                DBG_ERR!("Unable to serialize {}: {:?}", key, e);
+                return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+            }
+        };
+        self.store_bytes(key, &val_bytes)
+    }
+
+    fn keys(&self) -> Result<Vec<String>, Box<NTSTATUS>> {
+        self.tdb.keys().map_err(|e| {
+            DBG_ERR!("{:?}", e);
+            Box::new(NT_STATUS_UNSUCCESSFUL)
+        })
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub(crate) struct UserEntry {
+    pub(crate) upn: String,
+    pub(crate) uuid: String,
+    pub(crate) name: String,
+}
+
+impl TryFrom<&UserToken> for UserEntry {
+    type Error = MsalError;
+
+    fn try_from(token: &UserToken) -> Result<Self, Self::Error> {
+        Ok(UserEntry {
+            upn: token.spn()?,
+            uuid: token.uuid()?.to_string(),
+            name: token.id_token.name.clone(),
+        })
+    }
+}
+
+pub(crate) struct UserCache {
+    cache: BasicCache,
+}
+
+impl UserCache {
+    pub(crate) fn new(cache_path: &str) -> Result<Self, Box<NTSTATUS>> {
+        Ok(UserCache {
+            cache: BasicCache::new(cache_path)?,
+        })
+    }
+
+    pub(crate) fn fetch(&mut self, upn: &str) -> Option<UserEntry> {
+        self.cache.fetch::<UserEntry>(upn)
+    }
+
+    pub(crate) fn fetch_all(
+        &mut self,
+    ) -> Result<Vec<UserEntry>, Box<NTSTATUS>> {
+        let keys = self.cache.keys()?;
+        let mut res = Vec::new();
+        for upn in keys {
+            let entry = match self.cache.fetch::<UserEntry>(&upn) {
+                Some(entry) => entry,
+                None => {
+                    DBG_ERR!("Unable to fetch user {}", upn);
+                    return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                }
+            };
+            res.push(entry);
+        }
+        Ok(res)
+    }
+
+    pub(crate) fn store(
+        &mut self,
+        entry: UserEntry,
+    ) -> Result<(), Box<NTSTATUS>> {
+        let key = entry.upn.clone();
+        self.cache.store::<UserEntry>(&key, entry)
+    }
+}
+
+pub(crate) struct UidCache {
+    cache: BasicCache,
+}
+
+impl UidCache {
+    pub(crate) fn new(cache_path: &str) -> Result<Self, Box<NTSTATUS>> {
+        Ok(UidCache {
+            cache: BasicCache::new(cache_path)?,
+        })
+    }
+
+    pub(crate) fn store(
+        &mut self,
+        uid: uid_t,
+        upn: &str,
+    ) -> Result<(), Box<NTSTATUS>> {
+        let key = format!("{}", uid);
+        self.cache.store_bytes(&key, upn.as_bytes())
+    }
+
+    pub(crate) fn fetch(&mut self, uid: uid_t) -> Option<String> {
+        let key = format!("{}", uid);
+        self.cache.fetch_str(&key)
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub(crate) struct GroupEntry {
+    pub(crate) uuid: String,
+    members: HashSet<String>,
+}
+
+impl From<DirectoryObject> for GroupEntry {
+    fn from(obj: DirectoryObject) -> Self {
+        GroupEntry {
+            uuid: obj.id.clone(),
+            members: HashSet::new(),
+        }
+    }
+}
+
+impl GroupEntry {
+    pub(crate) fn add_member(&mut self, member: &str) {
+        // Only ever use lowercase names, otherwise the group
+        // memberships will have duplicates.
+        self.members.insert(member.to_lowercase());
+    }
+
+    pub(crate) fn remove_member(&mut self, member: &str) {
+        // Only ever use lowercase names, otherwise the group
+        // memberships will have duplicates.
+        self.members.remove(&member.to_lowercase());
+    }
+
+    pub(crate) fn into_with_member(obj: DirectoryObject, member: &str) -> Self {
+        let mut g: GroupEntry = obj.into();
+        g.add_member(member);
+        g
+    }
+
+    pub(crate) fn members(&self) -> Vec<String> {
+        self.members.clone().into_iter().collect::<Vec<String>>()
+    }
+}
+
+pub(crate) struct GroupCache {
+    cache: BasicCache,
+}
+
+impl GroupCache {
+    pub(crate) fn new(cache_path: &str) -> Result<Self, Box<NTSTATUS>> {
+        Ok(GroupCache {
+            cache: BasicCache::new(cache_path)?,
+        })
+    }
+
+    pub(crate) fn fetch(&mut self, uuid: &str) -> Option<GroupEntry> {
+        self.cache.fetch::<GroupEntry>(uuid)
+    }
+
+    pub(crate) fn fetch_all(
+        &mut self,
+    ) -> Result<Vec<GroupEntry>, Box<NTSTATUS>> {
+        let keys = self.cache.keys()?;
+        let mut res = Vec::new();
+        for uuid in keys {
+            let entry = match self.cache.fetch::<GroupEntry>(&uuid) {
+                Some(entry) => entry,
+                None => {
+                    DBG_ERR!("Unable to fetch group {}", uuid);
+                    return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                }
+            };
+            res.push(entry);
+        }
+        Ok(res)
+    }
+
+    pub(crate) fn merge_groups(
+        &mut self,
+        member: &str,
+        entries: Vec<GroupEntry>,
+    ) -> Result<(), Box<NTSTATUS>> {
+        // We need to ensure the member is removed from non-intersecting
+        // groups, otherwise we don't honor group membership removals.
+        let group_uuids: HashSet<String> = entries
+            .clone()
+            .into_iter()
+            .map(|g| g.uuid.clone())
+            .collect();
+        let existing_group_uuids = {
+            let cache = &self.cache;
+            match cache.keys() {
+                Ok(keys) => keys,
+                Err(e) => {
+                    DBG_ERR!("Unable to fetch groups: {:?}", e);
+                    return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                }
+            }
+        };
+        let existing_group_uuids: HashSet<String> =
+            existing_group_uuids.into_iter().collect();
+        let difference: HashSet<String> = existing_group_uuids
+            .difference(&group_uuids)
+            .cloned()
+            .collect();
+        for group_uuid in &difference {
+            if let Some(mut group) =
+                self.cache.fetch::<GroupEntry>(&group_uuid).clone()
+            {
+                group.remove_member(member);
+                if let Err(e) =
+                    self.cache.store::<GroupEntry>(&group.uuid.clone(), group)
+                {
+                    DBG_ERR!("Unable to store membership change: {:?}", e);
+                    return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                }
+            }
+        }
+
+        // Now add the new entries, merging with existing memberships
+        for group in entries {
+            match self.cache.fetch::<GroupEntry>(&group.uuid) {
+                Some(mut existing_group) => {
+                    // Merge with an existing entry
+                    existing_group.add_member(member);
+                    if let Err(e) = self.cache.store::<GroupEntry>(
+                        &existing_group.uuid.clone(),
+                        existing_group,
+                    ) {
+                        DBG_ERR!("Unable to store membership change: {:?}", e);
+                        return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                    }
+                }
+                None => {
+                    if let Err(e) = self
+                        .cache
+                        .store::<GroupEntry>(&group.uuid.clone(), group)
+                    {
+                        DBG_ERR!("Unable to store membership change: {:?}", e);
+                        return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+pub(crate) struct PrivateCache {
+    cache: BasicCache,
+}
+
+impl PrivateCache {
+    pub(crate) fn new(cache_path: &str) -> Result<Self, Box<NTSTATUS>> {
+        Ok(PrivateCache {
+            cache: BasicCache::new(cache_path)?,
+        })
+    }
+
+    pub(crate) fn hsm_pin_fetch_or_create(
+        &mut self,
+    ) -> Result<AuthValue, Box<NTSTATUS>> {
+        let hsm_pin = match self.cache.fetch_str("auth_value") {
+            Some(hsm_pin) => hsm_pin,
+            None => {
+                let auth_str = match AuthValue::generate() {
+                    Ok(auth_str) => auth_str,
+                    Err(e) => {
+                        DBG_ERR!("Failed to create hsm pin: {:?}", e);
+                        return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                    }
+                };
+                self.cache.store_bytes("auth_value", auth_str.as_bytes())?;
+                auth_str
+            }
+        };
+        match AuthValue::try_from(hsm_pin.as_bytes()) {
+            Ok(auth_value) => Ok(auth_value),
+            Err(e) => {
+                DBG_ERR!("Invalid hsm pin: {:?}", e);
+                return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+            }
+        }
+    }
+
+    pub(crate) fn loadable_machine_key_fetch_or_create(
+        &mut self,
+        hsm: &mut BoxedDynTpm,
+        auth_value: &AuthValue,
+    ) -> Result<LoadableMachineKey, Box<NTSTATUS>> {
+        match self
+            .cache
+            .fetch::<LoadableMachineKey>("loadable_machine_key")
+        {
+            Some(loadable_machine_key) => Ok(loadable_machine_key),
+            None => {
+                // No machine key found - create one, and store it.
+                let loadable_machine_key =
+                    match hsm.machine_key_create(&auth_value) {
+                        Ok(loadable_machine_key) => loadable_machine_key,
+                        Err(e) => {
+                            DBG_ERR!(
+                                "Unable to create hsm loadable \
+                                machine key: {:?}",
+                                e
+                            );
+                            return Err(Box::new(NT_STATUS_UNSUCCESSFUL));
+                        }
+                    };
+
+                self.cache.store::<LoadableMachineKey>(
+                    "loadable_machine_key",
+                    loadable_machine_key.clone(),
+                )?;
+
+                Ok(loadable_machine_key)
+            }
+        }
+    }
+
+    pub(crate) fn loadable_transport_key_fetch(
+        &mut self,
+        realm: &str,
+    ) -> Option<LoadableMsOapxbcRsaKey> {
+        let transport_key_tag = format!("{}/transport", realm);
+        self.cache
+            .fetch::<LoadableMsOapxbcRsaKey>(&transport_key_tag)
+    }
+
+    pub(crate) fn loadable_cert_key_fetch(
+        &mut self,
+        realm: &str,
+    ) -> Option<LoadableIdentityKey> {
+        let cert_key_tag = format!("{}/certificate", realm);
+        self.cache.fetch::<LoadableIdentityKey>(&cert_key_tag)
+    }
+
+    pub(crate) fn loadable_hello_key_fetch(
+        &mut self,
+        account_id: &str,
+    ) -> Option<LoadableIdentityKey> {
+        let hello_key_tag = format!("{}/hello", account_id);
+        self.cache.fetch::<LoadableIdentityKey>(&hello_key_tag)
+    }
+
+    pub(crate) fn loadable_cert_key_store(
+        &mut self,
+        realm: &str,
+        cert_key: LoadableIdentityKey,
+    ) -> Result<(), Box<NTSTATUS>> {
+        let cert_key_tag = format!("{}/certificate", realm);
+        self.cache
+            .store::<LoadableIdentityKey>(&cert_key_tag, cert_key)
+    }
+
+    pub(crate) fn loadable_hello_key_store(
+        &mut self,
+        account_id: &str,
+        hello_key: LoadableIdentityKey,
+    ) -> Result<(), Box<NTSTATUS>> {
+        let hello_key_tag = format!("{}/hello", account_id);
+        self.cache
+            .store::<LoadableIdentityKey>(&hello_key_tag, hello_key)
+    }
+
+    pub(crate) fn loadable_transport_key_store(
+        &mut self,
+        realm: &str,
+        transport_key: LoadableMsOapxbcRsaKey,
+    ) -> Result<(), Box<NTSTATUS>> {
+        let transport_key_tag = format!("{}/transport", realm);
+        self.cache
+            .store::<LoadableMsOapxbcRsaKey>(&transport_key_tag, transport_key)
+    }
+
+    pub(crate) fn device_id(&mut self, realm: &str) -> Option<String> {
+        let device_id_tag = format!("{}/device_id", realm);
+        self.cache.fetch_str(&device_id_tag)
+    }
+
+    pub(crate) fn device_id_store(
+        &mut self,
+        realm: &str,
+        device_id: &str,
+    ) -> Result<(), Box<NTSTATUS>> {
+        let device_id_tag = format!("{}/device_id", realm);
+        self.cache.store_bytes(&device_id_tag, device_id.as_bytes())
+    }
+}