]> git.ipfire.org Git - ipfire.org.git/blobdiff - src/backend/accounts.py
talk: Add a search
[ipfire.org.git] / src / backend / accounts.py
index faede4e31f10625746bb1e88ef8a4c9a34a1d91c..0151fb71edf8b769d13a9943bf8e5b0c3c1e1ea1 100644 (file)
@@ -14,9 +14,9 @@ from .misc import Object
 class Accounts(Object):
        def __iter__(self):
                # Only return developers (group with ID 1000)
-               accounts = self.search("(&(objectClass=posixAccount)(gidNumber=1000))")
+               accounts = self._search("(&(objectClass=posixAccount)(gidNumber=1000))")
 
-               return iter(accounts)
+               return iter(sorted(accounts))
 
        @property
        def ldap(self):
@@ -33,7 +33,7 @@ class Accounts(Object):
 
                return self._ldap
 
-       def _search(self, query, attrlist=None, limit=0):
+       def _query(self, query, attrlist=None, limit=0):
                logging.debug("Performing LDAP query: %s" % query)
 
                search_base = self.settings.get("ldap_search_base")
@@ -49,28 +49,41 @@ class Accounts(Object):
 
                return results
 
-       def search(self, query, limit=0):
-               results = self._search(query, limit=limit)
-
+       def _search(self, query, attrlist=None, limit=0):
                accounts = []
-               for dn, attrs in results:
+
+               for dn, attrs in self._query(query, attrlist=attrlist, limit=limit):
                        account = Account(self.backend, dn, attrs)
                        accounts.append(account)
 
+               return accounts
+
+       def search(self, query):
+               # Search for exact matches
+               accounts = self._search("(&(objectClass=posixAccount) \
+                       (|(uid=%s)(mail=%s)(sipAuthenticationUser=%s)(telephoneNumber=%s)(homePhone=%s)(mobile=%s)))" \
+                       % (query, query, query, query, query, query))
+
+               # Find accounts by name
+               if not accounts:
+                       for account in self._search("(&(objectClass=posixAccount)(cn=*%s*))" % query):
+                               if not account in accounts:
+                                       accounts.append(account)
+
                return sorted(accounts)
 
-       def search_one(self, query):
-               result = self.search(query, limit=1)
+       def _search_one(self, query):
+               result = self._search(query, limit=1)
                assert len(result) <= 1
 
                if result:
                        return result[0]
 
        def get_by_uid(self, uid):
-               return self.search_one("(&(objectClass=posixAccount)(uid=%s))" % uid)
+               return self._search_one("(&(objectClass=posixAccount)(uid=%s))" % uid)
 
        def get_by_mail(self, mail):
-               return self.search_one("(&(objectClass=posixAccount)(mail=%s))" % mail)
+               return self._search_one("(&(objectClass=posixAccount)(mail=%s))" % mail)
 
        find = get_by_uid
 
@@ -82,7 +95,7 @@ class Accounts(Object):
                return self.get_by_mail(s)
 
        def get_by_sip_id(self, sip_id):
-               return self.search_one("(|(&(objectClass=sipUser)(sipAuthenticationUser=%s)) \
+               return self._search_one("(|(&(objectClass=sipUser)(sipAuthenticationUser=%s)) \
                        (&(objectClass=sipRoutingObject)(sipLocalAddress=%s)))" % (sip_id, sip_id))
 
        # Session stuff
@@ -218,20 +231,19 @@ class Account(Object):
        def first_name(self):
                return self._get_first_attribute("givenName")
 
-       @property
+       @lazy_property
        def groups(self):
-               if not hasattr(self, "_groups"):
-                       self._groups = []
+               groups = []
 
-                       res = self.accounts._search("(&(objectClass=posixGroup) \
-                               (memberUid=%s))" % self.uid, ["cn"])
+               res = self.accounts._query("(&(objectClass=posixGroup) \
+                       (memberUid=%s))" % self.uid, ["cn"])
 
-                       for dn, attrs in res:
-                               cns = attrs.get("cn")
-                               if cns:
-                                       self._groups.append(cns[0].decode())
+               for dn, attrs in res:
+                       cns = attrs.get("cn")
+                       if cns:
+                               groups.append(cns[0].decode())
 
-               return self._groups
+               return groups
 
        @property
        def address(self):