]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implementation for <aggregate_fun> FILTER (WHERE ...)
authorIlja Everilä <saarni@gmail.com>
Wed, 10 Sep 2014 08:34:33 +0000 (11:34 +0300)
committerIlja Everilä <saarni@gmail.com>
Wed, 10 Sep 2014 08:34:33 +0000 (11:34 +0300)
lib/sqlalchemy/__init__.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py

index 8535661723513498b2962513e00aa2bec4038896..1af0de3ba365d6d22919fade38e5acf5e7794c72 100644 (file)
@@ -7,6 +7,7 @@
 
 
 from .sql import (
+    aggregatefilter,
     alias,
     and_,
     asc,
index 4d013859c5d82512a4a2f594eb34af38c77008e8..8fbf1b536c5aa7d6c065c501ff09de3bc2fe7402 100644 (file)
@@ -19,6 +19,7 @@ from .expression import (
     Selectable,
     TableClause,
     Update,
+    aggregatefilter,
     alias,
     and_,
     asc,
index 5149fa4feebcdba26ea2433a3e3528da3823530f..6ebd61e9c8b07329bb6d4e65878289123300522c 100644 (file)
@@ -760,6 +760,12 @@ class SQLCompiler(Compiled):
             )
         )
 
+    def visit_aggregatefilter(self, aggregatefilter, **kwargs):
+        return "%s FILTER (WHERE %s)" % (
+            aggregatefilter.func._compiler_dispatch(self, **kwargs),
+            aggregatefilter.criterion._compiler_dispatch(self, **kwargs)
+        )
+
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
         return "EXTRACT(%s FROM %s)" % (
index 8ec0aa7002d74eaf2c21b6ffe9c1fda463c0a90f..5562e80d7b6cd3944e5096c9261fcf95dbe99819 100644 (file)
@@ -2888,6 +2888,71 @@ class Over(ColumnElement):
         ))
 
 
+class AggregateFilter(ColumnElement):
+    """Represent an aggregate FILTER clause.
+
+    This is a special operator against aggregate functions,
+    which controls which rows are passed to it.
+    It's supported only by certain database backends.
+
+    """
+    __visit_name__ = 'aggregatefilter'
+
+    criterion = None
+
+    def __init__(self, func, *criterion):
+        """Produce an :class:`.AggregateFilter` object against a function.
+
+        Used against aggregate functions,
+        for database backends that support aggregate "FILTER" clause.
+
+        E.g.::
+
+        from sqlalchemy import aggregatefilter
+        aggregatefilter(func.count(1), MyClass.name == 'some name')
+
+        Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+        This function is also available from the :data:`~.expression.func`
+        construct itself via the :meth:`.FunctionElement.filter` method.
+
+        """
+        self.func = func
+        self.filter(*criterion)
+
+    def filter(self, *criterion):
+        for criterion in list(criterion):
+            criterion = _expression_literal_as_text(criterion)
+
+            if self.criterion is not None:
+                self.criterion = self.criterion & criterion
+            else:
+                self.criterion = criterion
+
+        return self
+
+    @util.memoized_property
+    def type(self):
+        return self.func.type
+
+    def get_children(self, **kwargs):
+        return [c for c in
+                (self.func, self.criterion)
+                if c is not None]
+
+    def _copy_internals(self, clone=_clone, **kw):
+        self.func = clone(self.func, **kw)
+        if self.criterion is not None:
+            self.criterion = clone(self.criterion, **kw)
+
+    @property
+    def _from_objects(self):
+        return list(itertools.chain(
+            *[c._from_objects for c in (self.func, self.criterion)
+              if c is not None]
+        ))
+
+
 class Label(ColumnElement):
     """Represents a column label (AS).
 
index d96f048b98edd7ac3ae9c429f291b2cf2bcf560f..7b22cab3e7212433955ee57348d7e3852aca446e 100644 (file)
@@ -36,7 +36,7 @@ from .elements import ClauseElement, ColumnElement,\
     True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
     Grouping, not_, \
     collate, literal_column, between,\
-    literal, outparam, type_coerce, ClauseList
+    literal, outparam, type_coerce, ClauseList, AggregateFilter
 
 from .elements import SavepointClause, RollbackToSavepointClause, \
     ReleaseSavepointClause
@@ -97,6 +97,8 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
 insert = public_factory(Insert, ".expression.insert")
 update = public_factory(Update, ".expression.update")
 delete = public_factory(Delete, ".expression.delete")
+aggregatefilter = public_factory(
+    AggregateFilter, ".expression.aggregatefilter")
 
 
 # internal functions still being called from tests and the ORM,
index 7efb1e916e57e240b1c473553856b7824788feb8..46f3e27dc0590f6c068348345b2c56078d02e7ee 100644 (file)
@@ -12,7 +12,7 @@ from . import sqltypes, schema
 from .base import Executable, ColumnCollection
 from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
     literal_column, _type_from_args, ColumnElement, _clone,\
-    Over, BindParameter
+    Over, BindParameter, AggregateFilter
 from .selectable import FromClause, Select, Alias
 
 from . import operators
@@ -116,6 +116,28 @@ class FunctionElement(Executable, ColumnElement, FromClause):
         """
         return Over(self, partition_by=partition_by, order_by=order_by)
 
+    def filter(self, *criterion):
+        """Produce a FILTER clause against this function.
+
+        Used against aggregate functions,
+        for database backends that support aggregate "FILTER" clause.
+
+        The expression::
+
+            func.count(1).filter(True)
+
+        is shorthand for::
+
+            from sqlalchemy import aggregatefilter
+            aggregatefilter(func.count(1), True)
+
+        See :func:`~.expression.aggregatefilter` for a full description.
+
+        """
+        if not criterion:
+            return self
+        return AggregateFilter(self, *criterion)
+
     @property
     def _from_objects(self):
         return self.clauses._from_objects