]> git.ipfire.org Git - ipfire.org.git/blobdiff - src/backend/accounts.py
accounts: Implement page searches for LDAP
[ipfire.org.git] / src / backend / accounts.py
index ee7fca0870e4d3652c173e26f4a08ec19d1b9ae9..b1160ac2e76df782ee1a325185fd02f3a3c45ae6 100644 (file)
@@ -220,16 +220,47 @@ class Accounts(Object):
 
                t = time.time()
 
-               results = self.ldap.search_ext_s(search_base or self.search_base,
-                       ldap.SCOPE_SUBTREE, query, attrlist=attrlist, sizelimit=limit)
+               # Ask for up to 512 results being returned at a time
+               page_control = ldap.controls.SimplePagedResultsControl(True, size=512, cookie="")
+
+               results = []
+               pages = 0
+
+               # Perform the search
+               while True:
+                       response = self.ldap.search_ext(search_base or self.search_base,
+                               ldap.SCOPE_SUBTREE, query, attrlist=attrlist, sizelimit=limit,
+                               serverctrls=[page_control],
+                       )
+
+                       # Fetch all results
+                       type, data, rmsgid, serverctrls = self.ldap.result3(response)
+
+                       # Append to local copy
+                       results += data
+                       pages += 1
+
+                       controls = [c for c in serverctrls
+                               if c.controlType == ldap.controls.SimplePagedResultsControl.controlType]
+
+                       if not controls:
+                               logging.warning("The server ignores RFC 2696 control")
+                               break
+
+                       # Set the cookie for more results
+                       page_control.cookie = controls[0].cookie
+
+                       # There are no more results
+                       if not page_control.cookie:
+                               break
 
                # Log time it took to perform the query
-               logging.debug("Query took %.2fms" % ((time.time() - t) * 1000.0))
+               logging.debug("Query took %.2fms (%s page(s))" % ((time.time() - t) * 1000.0, pages))
 
                return results
 
        def _count(self, query):
-               res = self._query(query, attrlist=["dn"], limit=INT_MAX)
+               res = self._query(query, attrlist=["dn"])
 
                return len(res)