]> git.ipfire.org Git - ipfire.org.git/commitdiff
accounts: Create abstract class for LDAP objects
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 19 Nov 2019 12:02:41 +0000 (12:02 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 19 Nov 2019 12:02:41 +0000 (12:02 +0000)
This allows us to use our wrapping layer for groups as well

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/backend/accounts.py

index c5982cacffe1b5c25b10e0f1073a95eae2992df6..0d48270da99c179a29a370ae5d267d7ac5fe4a29 100644 (file)
@@ -26,6 +26,136 @@ from .misc import Object
 # Set the client keytab name
 os.environ["KRB5_CLIENT_KTNAME"] = "/etc/ipfire.org/ldap.keytab"
 
+class LDAPObject(Object):
+       def init(self, dn, attrs=None):
+               self.dn = dn
+
+               self.attributes = attrs or {}
+
+       def __eq__(self, other):
+               if isinstance(other, self.__class__):
+                       return self.dn == other.dn
+
+       @property
+       def ldap(self):
+               return self.accounts.ldap
+
+       def _exists(self, key):
+               try:
+                       self.attributes[key]
+               except KeyError:
+                       return False
+
+               return True
+
+       def _get(self, key):
+               for value in self.attributes.get(key, []):
+                       yield value
+
+       def _get_bytes(self, key, default=None):
+               for value in self._get(key):
+                       return value
+
+               return default
+
+       def _get_strings(self, key):
+               for value in self._get(key):
+                       yield value.decode()
+
+       def _get_string(self, key, default=None):
+               for value in self._get_strings(key):
+                       return value
+
+               return default
+
+       def _get_phone_numbers(self, key):
+               for value in self._get_strings(key):
+                       yield phonenumbers.parse(value, None)
+
+       def _get_timestamp(self, key):
+               value = self._get_string(key)
+
+               # Parse the timestamp value and returns a datetime object
+               if value:
+                       return datetime.datetime.strptime(value, "%Y%m%d%H%M%SZ")
+
+       def _modify(self, modlist):
+               logging.debug("Modifying %s: %s" % (self.dn, modlist))
+
+               # Authenticate before performing any write operations
+               self.accounts._authenticate()
+
+               # Run modify operation
+               self.ldap.modify_s(self.dn, modlist)
+
+               # Clear cache
+               self._clear_cache()
+
+       def _clear_cache(self):
+               """
+                       Clears cache
+               """
+               pass
+
+       def _set(self, key, values):
+               current = self._get(key)
+
+               # Don't do anything if nothing has changed
+               if list(current) == values:
+                       return
+
+               # Remove all old values and add all new ones
+               modlist = []
+
+               if self._exists(key):
+                       modlist.append((ldap.MOD_DELETE, key, None))
+
+               # Add new values
+               if values:
+                       modlist.append((ldap.MOD_ADD, key, values))
+
+               # Run modify operation
+               self._modify(modlist)
+
+               # Update cache
+               self.attributes.update({ key : values })
+
+       def _set_bytes(self, key, values):
+               return self._set(key, values)
+
+       def _set_strings(self, key, values):
+               return self._set(key, [e.encode() for e in values if e])
+
+       def _set_string(self, key, value):
+               return self._set_strings(key, [value,])
+
+       def _add(self, key, values):
+               modlist = [
+                       (ldap.MOD_ADD, key, values),
+               ]
+
+               self._modify(modlist)
+
+       def _add_strings(self, key, values):
+               return self._add(key, [e.encode() for e in values])
+
+       def _add_string(self, key, value):
+               return self._add_strings(key, [value,])
+
+       def _delete(self, key, values):
+               modlist = [
+                       (ldap.MOD_DELETE, key, values),
+               ]
+
+               self._modify(modlist)
+
+       def _delete_strings(self, key, values):
+               return self._delete(key, [e.encode() for e in values])
+
+       def _delete_string(self, key, value):
+               return self._delete_strings(key, [value,])
+
+
 class Accounts(Object):
        def init(self):
                self.search_base = self.settings.get("ldap_search_base")
@@ -361,13 +491,7 @@ class Accounts(Object):
                return h.hexdigest()
 
 
-class Account(Object):
-       def __init__(self, backend, dn, attrs=None):
-               Object.__init__(self, backend)
-               self.dn = dn
-
-               self.attributes = attrs or {}
-
+class Account(LDAPObject):
        def __str__(self):
                if self.nickname:
                        return self.nickname
@@ -377,127 +501,14 @@ class Account(Object):
        def __repr__(self):
                return "<%s %s>" % (self.__class__.__name__, self.dn)
 
-       def __eq__(self, other):
-               if isinstance(other, self.__class__):
-                       return self.dn == other.dn
-
        def __lt__(self, other):
                if isinstance(other, self.__class__):
                        return self.name < other.name
 
-       @property
-       def ldap(self):
-               return self.accounts.ldap
-
-       def _exists(self, key):
-               try:
-                       self.attributes[key]
-               except KeyError:
-                       return False
-
-               return True
-
-       def _get(self, key):
-               for value in self.attributes.get(key, []):
-                       yield value
-
-       def _get_bytes(self, key, default=None):
-               for value in self._get(key):
-                       return value
-
-               return default
-
-       def _get_strings(self, key):
-               for value in self._get(key):
-                       yield value.decode()
-
-       def _get_string(self, key, default=None):
-               for value in self._get_strings(key):
-                       return value
-
-               return default
-
-       def _get_phone_numbers(self, key):
-               for value in self._get_strings(key):
-                       yield phonenumbers.parse(value, None)
-
-       def _get_timestamp(self, key):
-               value = self._get_string(key)
-
-               # Parse the timestamp value and returns a datetime object
-               if value:
-                       return datetime.datetime.strptime(value, "%Y%m%d%H%M%SZ")
-
-       def _modify(self, modlist):
-               logging.debug("Modifying %s: %s" % (self.dn, modlist))
-
-               # Authenticate before performing any write operations
-               self.accounts._authenticate()
-
-               # Run modify operation
-               self.ldap.modify_s(self.dn, modlist)
-
+       def _clear_cache(self):
                # Delete cached attributes
                self.memcache.delete("accounts:%s:attrs" % self.dn)
 
-       def _set(self, key, values):
-               current = self._get(key)
-
-               # Don't do anything if nothing has changed
-               if list(current) == values:
-                       return
-
-               # Remove all old values and add all new ones
-               modlist = []
-
-               if self._exists(key):
-                       modlist.append((ldap.MOD_DELETE, key, None))
-
-               # Add new values
-               if values:
-                       modlist.append((ldap.MOD_ADD, key, values))
-
-               # Run modify operation
-               self._modify(modlist)
-
-               # Update cache
-               self.attributes.update({ key : values })
-
-       def _set_bytes(self, key, values):
-               return self._set(key, values)
-
-       def _set_strings(self, key, values):
-               return self._set(key, [e.encode() for e in values if e])
-
-       def _set_string(self, key, value):
-               return self._set_strings(key, [value,])
-
-       def _add(self, key, values):
-               modlist = [
-                       (ldap.MOD_ADD, key, values),
-               ]
-
-               self._modify(modlist)
-
-       def _add_strings(self, key, values):
-               return self._add(key, [e.encode() for e in values])
-
-       def _add_string(self, key, value):
-               return self._add_strings(key, [value,])
-
-       def _delete(self, key, values):
-               modlist = [
-                       (ldap.MOD_DELETE, key, values),
-               ]
-
-               self._modify(modlist)
-
-       def _delete_strings(self, key, values):
-               return self._delete(key, [e.encode() for e in values])
-
-       def _delete_string(self, key, value):
-               return self._delete_strings(key, [value,])
-
        @lazy_property
        def kerberos_attributes(self):
                res = self.backend.accounts._query(
@@ -1171,12 +1182,7 @@ class Groups(Object):
                )
 
 
-class Group(Object):
-       def init(self, dn, attrs=None):
-               self.dn = dn
-
-               self.attributes = attrs or {}
-
+class Group(LDAPObject):
        def __repr__(self):
                if self.description:
                        return "<%s %s (%s)>" % (
@@ -1190,10 +1196,6 @@ class Group(Object):
        def __str__(self):
                return self.description or self.gid
 
-       def __eq__(self, other):
-               if isinstance(other, self.__class__):
-                       return self.gid == other.gid
-
        def __lt__(self, other):
                if isinstance(other, self.__class__):
                        return (self.description or self.gid) < (other.description or other.gid)
@@ -1219,44 +1221,29 @@ class Group(Object):
 
        @property
        def gid(self):
-               try:
-                       gid = self.attributes["cn"][0]
-               except KeyError:
-                       return None
-
-               return gid.decode()
+               return self._get_string("cn")
 
        @property
        def description(self):
-               try:
-                       description = self.attributes["description"][0]
-               except KeyError:
-                       return None
-
-               return description.decode()
+               return self._get_string("description")
 
        @property
        def email(self):
-               try:
-                       email = self.attributes["mail"][0]
-               except KeyError:
-                       return None
-
-               return email.decode()
+               return self._get_string("mail")
 
        @lazy_property
        def members(self):
                members = []
 
                # Get all members by DN
-               for dn in self.attributes.get("member", []):
-                       member = self.backend.accounts.get_by_dn(dn.decode())
+               for dn in self._get_strings("member"):
+                       member = self.backend.accounts.get_by_dn(dn)
                        if member:
                                members.append(member)
 
-               # Get all meembers by UID
-               for uid in self.attributes.get("memberUid", []):
-                       member = self.backend.accounts.get_by_uid(uid.decode())
+               # Get all members by UID
+               for uid in self._get_strings("memberUid"):
+                       member = self.backend.accounts.get_by_uid(uid)
                        if member:
                                members.append(member)