]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Add ability to specify multiple attributes in sort,
authormayur-srivastava <mayur.p.srivastava@gmail.com>
Fri, 31 May 2019 20:50:15 +0000 (16:50 -0400)
committermayur-srivastava <mayur.p.srivastava@gmail.com>
Fri, 31 May 2019 20:50:15 +0000 (16:50 -0400)
e.g. foo|sort(attribute='name,id').

This fixes #670

jinja2/filters.py
tests/test_filters.py

index 10102b03110de214183cddf4ca6ab381045e6b9d..2245742241154eb253094ae19e94a857c42ecea6 100644 (file)
@@ -84,6 +84,44 @@ def make_attrgetter(environment, attribute, postprocess=None):
     return attrgetter
 
 
+def make_multi_attrgetter(environment, attribute, postprocess=None):
+    """Returns a callable that looks up the given comma separated
+    attributes from a passed object with the rules of the environment.
+    Dots are allowed to access attributes of each attribute.  Integer
+    parts in paths are looked up as integers.
+
+    The value returned by the returned callable is a list of extracted
+    attribute values.
+
+    Examples of attribute: "attr1,attr2", "attr1.inner1.0,attr2.inner2.0", etc.
+    """
+    def _prepare_attribute_parts(attr):
+        if attr is None:
+            return []
+        elif isinstance(attribute, string_types):
+            return [int(x) if x.isdigit() else x for x in attr.split('.')]
+        else:
+            return [attr]
+
+    attribute_parts = attribute.split(',') if isinstance(attribute, string_types) else [attribute]
+    attribute = [_prepare_attribute_parts(attribute_part) for attribute_part in attribute_parts]
+
+    def attrgetter(item):
+        items = [None] * len(attribute)
+        for i, attribute_part in enumerate(attribute):
+            item_i = item
+            for part in attribute_part:
+                item_i = environment.getitem(item_i, part)
+
+            if postprocess is not None:
+                item_i = postprocess(item_i)
+
+            items[i] = item_i
+        return items
+
+    return attrgetter
+
+
 def do_forceescape(value):
     """Enforce HTML escaping.  This will probably double escape variables."""
     if hasattr(value, '__html__'):
@@ -270,8 +308,10 @@ def do_sort(
 
     .. versionchanged:: 2.6
        The `attribute` parameter was added.
+       The attribute parameter can contain multiple comma separated
+       attributes, e.g. attr1,attr2.
     """
-    key_func = make_attrgetter(
+    key_func = make_multi_attrgetter(
         environment, attribute,
         postprocess=ignore_case if not case_sensitive else None
     )
index 60808f2edbd0243961f0fd4d406dfceca5497d97..0a337b6327a4c9b4debafec2514df734381e8ecb 100644 (file)
@@ -23,6 +23,16 @@ class Magic(object):
         return text_type(self.value)
 
 
+@implements_to_string
+class Magic2(object):
+    def __init__(self, value1, value2):
+        self.value1 = value1
+        self.value2 = value2
+
+    def __str__(self):
+        return u'(%s,%s)' % (text_type(self.value1), text_type(self.value2))
+
+
 @pytest.mark.filter
 class TestFilter(object):
 
@@ -417,6 +427,29 @@ class TestFilter(object):
         tmpl = env.from_string('''{{ items|sort(attribute='value')|join }}''')
         assert tmpl.render(items=map(Magic, [3, 2, 4, 1])) == '1234'
 
+    def test_sort5(self, env):
+        tmpl = env.from_string('''{{ items|sort(attribute='value.0')|join }}''')
+        assert tmpl.render(items=map(Magic, [[3], [2], [4], [1]])) == '[1][2][3][4]'
+
+    def test_sort6(self, env):
+        tmpl = env.from_string('''{{ items|sort(attribute='value1,value2')|join }}''')
+        assert (tmpl.render(items=map(
+            lambda x: Magic2(x[0], x[1]), [(3, 1), (2, 2), (2, 1), (2, 5)]))
+            == '(2,1)(2,2)(2,5)(3,1)')
+
+    def test_sort7(self, env):
+        tmpl = env.from_string('''{{ items|sort(attribute='value2,value1')|join }}''')
+        assert (tmpl.render(items=map(lambda x: Magic2(x[0], x[1]), [(3, 1), (2, 2), (2, 1), (2, 5)])) ==
+                '(2,1)(3,1)(2,2)(2,5)')
+
+    def test_sort8(self, env):
+        tmpl = env.from_string(
+            '''{{ items|sort(attribute='value1.0,value2.0')|join }}''')
+        assert (tmpl.render(items=map(
+            lambda x: Magic2(x[0], x[1]),
+            [([3], [1]), ([2], [2]), ([2], [1]), ([2], [5])]))
+            == '([2],[1])([2],[2])([2],[5])([3],[1])')
+
     def test_unique(self, env):
         t = env.from_string('{{ "".join(["b", "A", "a", "b"]|unique) }}')
         assert t.render() == "bA"