]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
ldb sort: allow sorting on attributes not returned in search
authorDouglas Bagnall <douglas.bagnall@catalyst.net.nz>
Tue, 8 Mar 2016 01:43:40 +0000 (14:43 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Wed, 9 Mar 2016 09:32:17 +0000 (10:32 +0100)
The attribute is added to the search request, then peeled off again
before the sort module passes the results on.

Signed-off-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Garming Sam <garming@catalyst.net.nz>
lib/ldb/modules/sort.c
source4/dsdb/tests/python/sort.py

index 1b762f7e51b56370e4bdf7159f5eb0392b69a4b1..19cf60b776b0bfc026e4a026a0eef0dec4c69167 100644 (file)
@@ -56,6 +56,7 @@ struct sort_context {
        char **referrals;
        unsigned int num_msgs;
        unsigned int num_refs;
+       const char *extra_sort_key;
 
        const struct ldb_schema_attribute *a;
        int sort_result;
@@ -162,7 +163,9 @@ static int server_sort_results(struct sort_context *ac)
 
                ares->type = LDB_REPLY_ENTRY;
                ares->message = talloc_move(ares, &ac->msgs[i]);
-
+               if (ac->extra_sort_key) {
+                       ldb_msg_remove_attr(ares->message, ac->extra_sort_key);
+               }
                ret = ldb_module_send_entry(ac->req, ares->message, ares->controls);
                if (ret != LDB_SUCCESS) {
                        return ret;
@@ -256,6 +259,9 @@ static int server_sort_search(struct ldb_module *module, struct ldb_request *req
        struct sort_context *ac;
        struct ldb_context *ldb;
        int ret;
+       const char * const *attrs;
+       size_t n_attrs, i;
+       const char *sort_attr;
 
        ldb = ldb_module_get_ctx(module);
 
@@ -303,6 +309,40 @@ static int server_sort_search(struct ldb_module *module, struct ldb_request *req
                }
        }
 
+       /* We are asked to sort on an attribute, and if that attribute is not
+          already in the search attributes we need to add it (and later
+          remove it on the return journey).
+       */
+       sort_attr = sort_ctrls[0]->attributeName;
+       if (req->op.search.attrs == NULL) {
+               /* This means all non-operational attributes, which means
+                  there's nothing to add. */
+               attrs = NULL;
+       } else {
+               n_attrs = 0;
+               while (req->op.search.attrs[n_attrs] != NULL) {
+                       if (sort_attr &&
+                           strcmp(req->op.search.attrs[n_attrs], sort_attr) == 0) {
+                               sort_attr = NULL;
+                       }
+                       n_attrs++;
+               }
+
+               if (sort_attr == NULL) {
+                       attrs = req->op.search.attrs;
+               } else {
+                       const char **tmp = talloc_array(ac, const char *, n_attrs + 2);
+
+                       for (i = 0; i < n_attrs; i++) {
+                               tmp[i] = req->op.search.attrs[i];
+                       }
+                       ac->extra_sort_key = sort_attr;
+                       tmp[n_attrs] = sort_attr;
+                       tmp[n_attrs + 1] = NULL;
+                       attrs = tmp;
+               }
+       }
+
        ac->attributeName = sort_ctrls[0]->attributeName;
        ac->orderingRule = sort_ctrls[0]->orderingRule;
        ac->reverse = sort_ctrls[0]->reverse;
@@ -311,7 +351,7 @@ static int server_sort_search(struct ldb_module *module, struct ldb_request *req
                                        req->op.search.base,
                                        req->op.search.scope,
                                        req->op.search.tree,
-                                       req->op.search.attrs,
+                                       attrs,
                                        req->controls,
                                        ac,
                                        server_sort_search_callback,
index 436cb8c528543c9ef0b1dd0363038a9d4a3be874..c4d2c44526857c4755719c44a1a165295c512721 100644 (file)
@@ -273,9 +273,62 @@ class BaseSortTests(samba.tests.TestCase):
 
                     self.assertEquals(expected_order, received_order)
 
+    def _test_server_sort_different_attr(self):
+
+        def cmp_locale(a, b):
+            return locale.strcoll(a[0], b[0])
+
+        def cmp_binary(a, b):
+            return cmp(a[0], b[0])
+
+        def cmp_numeric(a, b):
+            return cmp(int(a[0]), int(b[0]))
+
+        # For testing simplicity, the attributes in here need to be
+        # unique for each user. Otherwise there are multiple possible
+        # valid answers.
+        sort_functions = {'cn': cmp_binary,
+                          "employeeNumber": cmp_locale,
+                          "accountExpires": cmp_numeric,
+                          "msTSExpireDate4":cmp_binary}
+        attrs = sort_functions.keys()
+        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])
+
+        for sort_attr, result_attr in attr_pairs:
+            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
+                             for x in self.users),
+                             cmp=sort_functions[sort_attr])
+            reverse = list(reversed(forward))
+
+            for rev in (0, 1):
+                res = self.ldb.search(self.ou,
+                                      scope=ldb.SCOPE_ONELEVEL,
+                                      attrs=[result_attr],
+                                      controls=["server_sort:1:%d:%s" %
+                                                (rev, sort_attr)])
+                self.assertEqual(len(res), len(self.users))
+
+                expected_order = [x[1] for x in (forward, reverse)[rev]]
+                received_order = [norm(x[result_attr][0]) for x in res]
+
+                if expected_order != received_order:
+                    print sort_attr, result_attr, ['forward', 'reverse'][rev]
+                    print "expected", expected_order
+                    print "recieved", received_order
+                    print "unnormalised:", [x[result_attr][0] for x in res]
+                    print "unnormalised: «%s»" % '»  «'.join(x[result_attr][0]
+                                                             for x in res)
+                self.assertEquals(expected_order, received_order)
+                for x in res:
+                    if sort_attr in x:
+                        self.fail('the search for %s should not return %s' %
+                                  (result_attr, sort_attr))
+
 
 class SimpleSortTests(BaseSortTests):
     avoid_tricky_sort = True
+    def test_server_sort_different_attr(self):
+        self._test_server_sort_different_attr()
 
     def test_server_sort_default(self):
         self._test_server_sort_default()
@@ -296,6 +349,9 @@ class UnicodeSortTests(BaseSortTests):
     def test_server_sort_us_english(self):
         self._test_server_sort_us_english()
 
+    def test_server_sort_different_attr(self):
+        self._test_server_sort_different_attr()
+
 
 if "://" not in host:
     if os.path.isfile(host):