]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add ANY/ALL construct support for PostgreSQL's ARRAY type
authorAudrius Kažukauskas <audrius@neutrino.lt>
Mon, 28 Jan 2013 17:58:06 +0000 (19:58 +0200)
committerAudrius Kažukauskas <audrius@neutrino.lt>
Mon, 28 Jan 2013 17:58:06 +0000 (19:58 +0200)
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/test_postgresql.py

index c016c57a5d18536ba6b221c678bb193955c6d377..df141cce00cc7acc7547dde628a7641c93f503d2 100644 (file)
@@ -14,9 +14,9 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
 
     from sqlalchemy.dialects.postgresql import \
         ARRAY, BIGINT, BIT, BOOLEAN, BYTEA, CHAR, CIDR, DATE, \
-        DOUBLE_PRECISION, ENUM, FLOAT, INET, INTEGER, INTERVAL, \
-        MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, TIMESTAMP, \
-        UUID, VARCHAR
+        DOUBLE_PRECISION, ENUM, FLOAT, HSTORE, INET, INTEGER, \
+        INTERVAL, MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, \
+        TIMESTAMP, UUID, VARCHAR
 
 Types which are specific to PostgreSQL, or have PostgreSQL-specific
 construction arguments, are as follows:
@@ -29,6 +29,10 @@ construction arguments, are as follows:
     :members: __init__, Comparator
     :show-inheritance:
 
+.. autoclass:: Any
+
+.. autoclass:: All
+
 .. autoclass:: BIT
     :members: __init__
     :show-inheritance:
index 3b35bbcdb3938eabb32ffd8be72f99d9ac044f46..5dc1e555c6f50de159dac2a231a3da8545b6e0e8 100644 (file)
@@ -11,7 +11,7 @@ base.dialect = psycopg2.dialect
 from .base import \
     INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \
     INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \
-    DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array
+    DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All
 from .hstore import HSTORE, hstore
 
 __all__ = (
index 3de727e94630b379d387c98a3efc3c1bb58d5075..de150f03fe33b800dd6ede6d76e80e0553ed49f3 100644 (file)
@@ -365,6 +365,40 @@ class _Slice(expression.ColumnElement):
                             operators.getitem, slice_.stop)
 
 
+class Any(expression.ColumnElement):
+    """Return the clause ``left operator ANY (right)``.  ``right`` must be
+    an array expression.
+
+    See also:
+
+    :class:`.postgresql.ARRAY`
+    """
+    __visit_name__ = 'any'
+
+    def __init__(self, left, right, operator=operators.eq):
+        self.type = sqltypes.Boolean()
+        self.left = expression._literal_as_binds(left)
+        self.right = right
+        self.operator = operator
+
+
+class All(expression.ColumnElement):
+    """Return the clause ``left operator ALL (right)``.  ``right`` must be
+    an array expression.
+
+    See also:
+
+    :class:`.postgresql.ARRAY`
+    """
+    __visit_name__ = 'all'
+
+    def __init__(self, left, right, operator=operators.eq):
+        self.type = sqltypes.Boolean()
+        self.left = expression._literal_as_binds(left)
+        self.right = right
+        self.operator = operator
+
+
 class array(expression.Tuple):
     """A Postgresql ARRAY literal.
 
@@ -502,6 +536,20 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
             return self._binary_operate(self.expr, operators.getitem, index,
                             result_type=return_type)
 
+        def any(self, other, operator=operators.eq):
+            """Return ``other operator ANY (array)`` clause.  Argument places
+            are switched, because ANY requires array expression to be on the
+            right hand-side.
+            """
+            return Any(other, self.expr, operator=operator)
+
+        def all(self, other, operator=operators.eq):
+            """Return ``other operator ALL (array)`` clause.  Argument places
+            are switched, because ALL requires array expression to be on the
+            right hand-side.
+            """
+            return All(other, self.expr, operator=operator)
+
         def contains(self, other, **kwargs):
             """Boolean expression.  Test if elements are a superset of the
             elements of the argument array expression.
@@ -807,6 +855,20 @@ class PGCompiler(compiler.SQLCompiler):
                     self.process(element.stop, **kw),
                 )
 
+    def visit_any(self, element, **kw):
+        return "%s%sANY (%s)" % (
+            self.process(element.left, **kw),
+            compiler.OPERATORS[element.operator],
+            self.process(element.right, **kw)
+        )
+
+    def visit_all(self, element, **kw):
+        return "%s%sALL (%s)" % (
+            self.process(element.left, **kw),
+            compiler.OPERATORS[element.operator],
+            self.process(element.right, **kw)
+        )
+
     def visit_getitem_binary(self, binary, operator, **kw):
         return "%s[%s]" % (
                 self.process(binary.left, **kw),
index ff958e7b86a6963ceaf7f27012da299bad7c58e3..3337fa6ab7b17c7129e41f8ac4aee07814a79e35 100644 (file)
@@ -20,7 +20,7 @@ from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, ARRAY
 from sqlalchemy.util.compat import decimal
 from sqlalchemy.testing.util import round_decimal
-from sqlalchemy.sql import table, column
+from sqlalchemy.sql import table, column, operators
 import logging
 import re
 
@@ -290,6 +290,26 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             'x && %(x_1)s',
             checkparams={'x_1': [3]}
         )
+        self.assert_compile(
+            postgresql.Any(4, c),
+            '%(param_1)s = ANY (x)',
+            checkparams={'param_1': 4}
+        )
+        self.assert_compile(
+            c.any(5, operator=operators.ne),
+            '%(param_1)s != ANY (x)',
+            checkparams={'param_1': 5}
+        )
+        self.assert_compile(
+            postgresql.All(6, c, operator=operators.gt),
+            '%(param_1)s > ALL (x)',
+            checkparams={'param_1': 6}
+        )
+        self.assert_compile(
+            c.all(7, operator=operators.lt),
+            '%(param_1)s < ALL (x)',
+            checkparams={'param_1': 7}
+        )
 
     def test_array_literal_type(self):
         is_(postgresql.array([1, 2]).type._type_affinity, postgresql.ARRAY)
@@ -2274,6 +2294,34 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults):
                 [4, 5, 6]
             )
 
+    def test_array_any_exec(self):
+        with testing.db.connect() as conn:
+            conn.execute(
+                arrtable.insert(),
+                intarr=[4, 5, 6]
+            )
+            eq_(
+                conn.scalar(
+                    select([arrtable.c.intarr]).
+                        where(postgresql.Any(5, arrtable.c.intarr))
+                ),
+                [4, 5, 6]
+            )
+
+    def test_array_all_exec(self):
+        with testing.db.connect() as conn:
+            conn.execute(
+                arrtable.insert(),
+                intarr=[4, 5, 6]
+            )
+            eq_(
+                conn.scalar(
+                    select([arrtable.c.intarr]).
+                        where(arrtable.c.intarr.all(4, operator=operators.le))
+                ),
+                [4, 5, 6]
+            )
+
     @testing.provide_metadata
     def test_tuple_flag(self):
         metadata = self.metadata