From: Michael Tremer Date: Tue, 19 Nov 2019 12:02:41 +0000 (+0000) Subject: accounts: Create abstract class for LDAP objects X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=959d8d2a87b0e116709a452750acd2e324ac0ed1;p=ipfire.org.git accounts: Create abstract class for LDAP objects This allows us to use our wrapping layer for groups as well Signed-off-by: Michael Tremer --- diff --git a/src/backend/accounts.py b/src/backend/accounts.py index c5982cac..0d48270d 100644 --- a/src/backend/accounts.py +++ b/src/backend/accounts.py @@ -26,6 +26,136 @@ from .misc import Object # Set the client keytab name os.environ["KRB5_CLIENT_KTNAME"] = "/etc/ipfire.org/ldap.keytab" +class LDAPObject(Object): + def init(self, dn, attrs=None): + self.dn = dn + + self.attributes = attrs or {} + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.dn == other.dn + + @property + def ldap(self): + return self.accounts.ldap + + def _exists(self, key): + try: + self.attributes[key] + except KeyError: + return False + + return True + + def _get(self, key): + for value in self.attributes.get(key, []): + yield value + + def _get_bytes(self, key, default=None): + for value in self._get(key): + return value + + return default + + def _get_strings(self, key): + for value in self._get(key): + yield value.decode() + + def _get_string(self, key, default=None): + for value in self._get_strings(key): + return value + + return default + + def _get_phone_numbers(self, key): + for value in self._get_strings(key): + yield phonenumbers.parse(value, None) + + def _get_timestamp(self, key): + value = self._get_string(key) + + # Parse the timestamp value and returns a datetime object + if value: + return datetime.datetime.strptime(value, "%Y%m%d%H%M%SZ") + + def _modify(self, modlist): + logging.debug("Modifying %s: %s" % (self.dn, modlist)) + + # Authenticate before performing any write operations + self.accounts._authenticate() + + # Run modify operation + self.ldap.modify_s(self.dn, modlist) + + # Clear cache + self._clear_cache() + + def _clear_cache(self): + """ + Clears cache + """ + pass + + def _set(self, key, values): + current = self._get(key) + + # Don't do anything if nothing has changed + if list(current) == values: + return + + # Remove all old values and add all new ones + modlist = [] + + if self._exists(key): + modlist.append((ldap.MOD_DELETE, key, None)) + + # Add new values + if values: + modlist.append((ldap.MOD_ADD, key, values)) + + # Run modify operation + self._modify(modlist) + + # Update cache + self.attributes.update({ key : values }) + + def _set_bytes(self, key, values): + return self._set(key, values) + + def _set_strings(self, key, values): + return self._set(key, [e.encode() for e in values if e]) + + def _set_string(self, key, value): + return self._set_strings(key, [value,]) + + def _add(self, key, values): + modlist = [ + (ldap.MOD_ADD, key, values), + ] + + self._modify(modlist) + + def _add_strings(self, key, values): + return self._add(key, [e.encode() for e in values]) + + def _add_string(self, key, value): + return self._add_strings(key, [value,]) + + def _delete(self, key, values): + modlist = [ + (ldap.MOD_DELETE, key, values), + ] + + self._modify(modlist) + + def _delete_strings(self, key, values): + return self._delete(key, [e.encode() for e in values]) + + def _delete_string(self, key, value): + return self._delete_strings(key, [value,]) + + class Accounts(Object): def init(self): self.search_base = self.settings.get("ldap_search_base") @@ -361,13 +491,7 @@ class Accounts(Object): return h.hexdigest() -class Account(Object): - def __init__(self, backend, dn, attrs=None): - Object.__init__(self, backend) - self.dn = dn - - self.attributes = attrs or {} - +class Account(LDAPObject): def __str__(self): if self.nickname: return self.nickname @@ -377,127 +501,14 @@ class Account(Object): def __repr__(self): return "<%s %s>" % (self.__class__.__name__, self.dn) - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.dn == other.dn - def __lt__(self, other): if isinstance(other, self.__class__): return self.name < other.name - @property - def ldap(self): - return self.accounts.ldap - - def _exists(self, key): - try: - self.attributes[key] - except KeyError: - return False - - return True - - def _get(self, key): - for value in self.attributes.get(key, []): - yield value - - def _get_bytes(self, key, default=None): - for value in self._get(key): - return value - - return default - - def _get_strings(self, key): - for value in self._get(key): - yield value.decode() - - def _get_string(self, key, default=None): - for value in self._get_strings(key): - return value - - return default - - def _get_phone_numbers(self, key): - for value in self._get_strings(key): - yield phonenumbers.parse(value, None) - - def _get_timestamp(self, key): - value = self._get_string(key) - - # Parse the timestamp value and returns a datetime object - if value: - return datetime.datetime.strptime(value, "%Y%m%d%H%M%SZ") - - def _modify(self, modlist): - logging.debug("Modifying %s: %s" % (self.dn, modlist)) - - # Authenticate before performing any write operations - self.accounts._authenticate() - - # Run modify operation - self.ldap.modify_s(self.dn, modlist) - + def _clear_cache(self): # Delete cached attributes self.memcache.delete("accounts:%s:attrs" % self.dn) - def _set(self, key, values): - current = self._get(key) - - # Don't do anything if nothing has changed - if list(current) == values: - return - - # Remove all old values and add all new ones - modlist = [] - - if self._exists(key): - modlist.append((ldap.MOD_DELETE, key, None)) - - # Add new values - if values: - modlist.append((ldap.MOD_ADD, key, values)) - - # Run modify operation - self._modify(modlist) - - # Update cache - self.attributes.update({ key : values }) - - def _set_bytes(self, key, values): - return self._set(key, values) - - def _set_strings(self, key, values): - return self._set(key, [e.encode() for e in values if e]) - - def _set_string(self, key, value): - return self._set_strings(key, [value,]) - - def _add(self, key, values): - modlist = [ - (ldap.MOD_ADD, key, values), - ] - - self._modify(modlist) - - def _add_strings(self, key, values): - return self._add(key, [e.encode() for e in values]) - - def _add_string(self, key, value): - return self._add_strings(key, [value,]) - - def _delete(self, key, values): - modlist = [ - (ldap.MOD_DELETE, key, values), - ] - - self._modify(modlist) - - def _delete_strings(self, key, values): - return self._delete(key, [e.encode() for e in values]) - - def _delete_string(self, key, value): - return self._delete_strings(key, [value,]) - @lazy_property def kerberos_attributes(self): res = self.backend.accounts._query( @@ -1171,12 +1182,7 @@ class Groups(Object): ) -class Group(Object): - def init(self, dn, attrs=None): - self.dn = dn - - self.attributes = attrs or {} - +class Group(LDAPObject): def __repr__(self): if self.description: return "<%s %s (%s)>" % ( @@ -1190,10 +1196,6 @@ class Group(Object): def __str__(self): return self.description or self.gid - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.gid == other.gid - def __lt__(self, other): if isinstance(other, self.__class__): return (self.description or self.gid) < (other.description or other.gid) @@ -1219,44 +1221,29 @@ class Group(Object): @property def gid(self): - try: - gid = self.attributes["cn"][0] - except KeyError: - return None - - return gid.decode() + return self._get_string("cn") @property def description(self): - try: - description = self.attributes["description"][0] - except KeyError: - return None - - return description.decode() + return self._get_string("description") @property def email(self): - try: - email = self.attributes["mail"][0] - except KeyError: - return None - - return email.decode() + return self._get_string("mail") @lazy_property def members(self): members = [] # Get all members by DN - for dn in self.attributes.get("member", []): - member = self.backend.accounts.get_by_dn(dn.decode()) + for dn in self._get_strings("member"): + member = self.backend.accounts.get_by_dn(dn) if member: members.append(member) - # Get all meembers by UID - for uid in self.attributes.get("memberUid", []): - member = self.backend.accounts.get_by_uid(uid.decode()) + # Get all members by UID + for uid in self._get_strings("memberUid"): + member = self.backend.accounts.get_by_uid(uid) if member: members.append(member)