From: Audrius Kažukauskas Date: Mon, 28 Jan 2013 17:58:06 +0000 (+0200) Subject: Add ANY/ALL construct support for PostgreSQL's ARRAY type X-Git-Tag: rel_0_8_0~27^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dbdf4f25e2b1054e8f843f8ed0256ece86d68080;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add ANY/ALL construct support for PostgreSQL's ARRAY type --- diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index c016c57a5d..df141cce00 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -14,9 +14,9 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect:: from sqlalchemy.dialects.postgresql import \ ARRAY, BIGINT, BIT, BOOLEAN, BYTEA, CHAR, CIDR, DATE, \ - DOUBLE_PRECISION, ENUM, FLOAT, INET, INTEGER, INTERVAL, \ - MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, TIMESTAMP, \ - UUID, VARCHAR + DOUBLE_PRECISION, ENUM, FLOAT, HSTORE, INET, INTEGER, \ + INTERVAL, MACADDR, NUMERIC, REAL, SMALLINT, TEXT, TIME, \ + TIMESTAMP, UUID, VARCHAR Types which are specific to PostgreSQL, or have PostgreSQL-specific construction arguments, are as follows: @@ -29,6 +29,10 @@ construction arguments, are as follows: :members: __init__, Comparator :show-inheritance: +.. autoclass:: Any + +.. autoclass:: All + .. autoclass:: BIT :members: __init__ :show-inheritance: diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 3b35bbcdb3..5dc1e555c6 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -11,7 +11,7 @@ base.dialect = psycopg2.dialect from .base import \ INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ INET, CIDR, UUID, BIT, MACADDR, DOUBLE_PRECISION, TIMESTAMP, TIME, \ - DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array + DATE, BYTEA, BOOLEAN, INTERVAL, ARRAY, ENUM, dialect, array, Any, All from .hstore import HSTORE, hstore __all__ = ( diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3de727e946..de150f03fe 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -365,6 +365,40 @@ class _Slice(expression.ColumnElement): operators.getitem, slice_.stop) +class Any(expression.ColumnElement): + """Return the clause ``left operator ANY (right)``. ``right`` must be + an array expression. + + See also: + + :class:`.postgresql.ARRAY` + """ + __visit_name__ = 'any' + + def __init__(self, left, right, operator=operators.eq): + self.type = sqltypes.Boolean() + self.left = expression._literal_as_binds(left) + self.right = right + self.operator = operator + + +class All(expression.ColumnElement): + """Return the clause ``left operator ALL (right)``. ``right`` must be + an array expression. + + See also: + + :class:`.postgresql.ARRAY` + """ + __visit_name__ = 'all' + + def __init__(self, left, right, operator=operators.eq): + self.type = sqltypes.Boolean() + self.left = expression._literal_as_binds(left) + self.right = right + self.operator = operator + + class array(expression.Tuple): """A Postgresql ARRAY literal. @@ -502,6 +536,20 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): return self._binary_operate(self.expr, operators.getitem, index, result_type=return_type) + def any(self, other, operator=operators.eq): + """Return ``other operator ANY (array)`` clause. Argument places + are switched, because ANY requires array expression to be on the + right hand-side. + """ + return Any(other, self.expr, operator=operator) + + def all(self, other, operator=operators.eq): + """Return ``other operator ALL (array)`` clause. Argument places + are switched, because ALL requires array expression to be on the + right hand-side. + """ + return All(other, self.expr, operator=operator) + def contains(self, other, **kwargs): """Boolean expression. Test if elements are a superset of the elements of the argument array expression. @@ -807,6 +855,20 @@ class PGCompiler(compiler.SQLCompiler): self.process(element.stop, **kw), ) + def visit_any(self, element, **kw): + return "%s%sANY (%s)" % ( + self.process(element.left, **kw), + compiler.OPERATORS[element.operator], + self.process(element.right, **kw) + ) + + def visit_all(self, element, **kw): + return "%s%sALL (%s)" % ( + self.process(element.left, **kw), + compiler.OPERATORS[element.operator], + self.process(element.right, **kw) + ) + def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index ff958e7b86..3337fa6ab7 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -20,7 +20,7 @@ from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE, hstore, array, ARRAY from sqlalchemy.util.compat import decimal from sqlalchemy.testing.util import round_decimal -from sqlalchemy.sql import table, column +from sqlalchemy.sql import table, column, operators import logging import re @@ -290,6 +290,26 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'x && %(x_1)s', checkparams={'x_1': [3]} ) + self.assert_compile( + postgresql.Any(4, c), + '%(param_1)s = ANY (x)', + checkparams={'param_1': 4} + ) + self.assert_compile( + c.any(5, operator=operators.ne), + '%(param_1)s != ANY (x)', + checkparams={'param_1': 5} + ) + self.assert_compile( + postgresql.All(6, c, operator=operators.gt), + '%(param_1)s > ALL (x)', + checkparams={'param_1': 6} + ) + self.assert_compile( + c.all(7, operator=operators.lt), + '%(param_1)s < ALL (x)', + checkparams={'param_1': 7} + ) def test_array_literal_type(self): is_(postgresql.array([1, 2]).type._type_affinity, postgresql.ARRAY) @@ -2274,6 +2294,34 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults): [4, 5, 6] ) + def test_array_any_exec(self): + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[4, 5, 6] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr]). + where(postgresql.Any(5, arrtable.c.intarr)) + ), + [4, 5, 6] + ) + + def test_array_all_exec(self): + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[4, 5, 6] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.all(4, operator=operators.le)) + ), + [4, 5, 6] + ) + @testing.provide_metadata def test_tuple_flag(self): metadata = self.metadata