From 8134f2d0bfadc461b015104ff842f57962d89b0a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Audrius=20Ka=C5=BEukauskas?= Date: Tue, 20 Nov 2012 23:24:34 +0200 Subject: [PATCH] Add special containment operation methods for PG array type --- doc/build/dialects/postgresql.rst | 2 +- lib/sqlalchemy/dialects/postgresql/base.py | 42 ++++++++++- test/dialect/test_postgresql.py | 81 ++++++++++++++++++---- 3 files changed, 109 insertions(+), 16 deletions(-) diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index ac89ab1234..c016c57a5d 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -26,7 +26,7 @@ construction arguments, are as follows: .. autoclass:: array .. autoclass:: ARRAY - :members: __init__ + :members: __init__, Comparator :show-inheritance: .. autoclass:: BIT diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c7e84751d8..ed24bc1fe0 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -464,12 +464,21 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): as well as UPDATE statements when the :meth:`.Update.values` method is used:: - mytable.update().values({mytable.c.data[5]:7, - mytable.c.data[2:7]:[1,2,3]}) + mytable.update().values({ + mytable.c.data[5]: 7, + mytable.c.data[2:7]: [1, 2, 3] + }) + + :class:`.ARRAY` provides special methods for containment operations, + e.g.:: + + mytable.c.data.contains([1, 2]) + + For a full list of special methods see :class:`.ARRAY.Comparator`. .. versionadded:: 0.8 Added support for index and slice operations to the :class:`.ARRAY` type, including support for UPDATE - statements. + statements, and special array containment operations. The :class:`.ARRAY` type may not be supported on all DBAPIs. It is known to work on psycopg2 and not pg8000. @@ -482,6 +491,8 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): __visit_name__ = 'ARRAY' class Comparator(sqltypes.Concatenable.Comparator): + """Define comparison operations for :class:`.ARRAY`.""" + def __getitem__(self, index): if isinstance(index, slice): index = _Slice(index, self) @@ -491,6 +502,31 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine): return self._binary_operate(self.expr, operators.getitem, index, result_type=return_type) + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + """ + return self.expr.op('@>')(other) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.expr.op('<@')(other) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.expr.op('&&')(other) + + def _adapt_expression(self, op, other_comparator): + if isinstance(op, operators.custom_op): + if op.opstring in ['@>', '<@', '&&']: + return op, sqltypes.Boolean + return sqltypes.Concatenable.Comparator.\ + _adapt_expression(self, op, other_comparator) + comparator_factory = Comparator def __init__(self, item_type, as_tuple=False, dimensions=None): diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index fe45c441c8..46dab3df91 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -290,25 +290,41 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "CAST(x AS INTEGER[])" ) self.assert_compile( - c[5], - "x[%(x_1)s]", - checkparams={'x_1': 5} + c[5], + "x[%(x_1)s]", + checkparams={'x_1': 5} ) self.assert_compile( - c[5:7], - "x[%(x_1)s:%(x_2)s]", - checkparams={'x_2': 7, 'x_1': 5} + c[5:7], + "x[%(x_1)s:%(x_2)s]", + checkparams={'x_2': 7, 'x_1': 5} ) self.assert_compile( - c[5:7][2:3], - "x[%(x_1)s:%(x_2)s][%(param_1)s:%(param_2)s]", - checkparams={'x_2': 7, 'x_1': 5, 'param_1':2, 'param_2':3} + c[5:7][2:3], + "x[%(x_1)s:%(x_2)s][%(param_1)s:%(param_2)s]", + checkparams={'x_2': 7, 'x_1': 5, 'param_1':2, 'param_2':3} ) self.assert_compile( - c[5:7][3], - "x[%(x_1)s:%(x_2)s][%(param_1)s]", - checkparams={'x_2': 7, 'x_1': 5, 'param_1':3} + c[5:7][3], + "x[%(x_1)s:%(x_2)s][%(param_1)s]", + checkparams={'x_2': 7, 'x_1': 5, 'param_1':3} + ) + + self.assert_compile( + c.contains([1]), + 'x @> %(x_1)s', + checkparams={'x_1': [1]} + ) + self.assert_compile( + c.contained_by([2]), + 'x <@ %(x_1)s', + checkparams={'x_1': [2]} + ) + self.assert_compile( + c.overlap([3]), + 'x && %(x_1)s', + checkparams={'x_1': [3]} ) def test_array_literal_type(self): @@ -2244,6 +2260,47 @@ class ArrayTest(fixtures.TestBase, AssertsExecutionResults): [7, 8] ) + def test_array_contains_exec(self): + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[4, 5, 6] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.contains([4, 5])) + ), + [4, 5, 6] + ) + + def test_array_contained_by_exec(self): + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[6, 5, 4] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr.contained_by([4, 5, 6, 7])]) + ), + True + ) + + def test_array_overlap_exec(self): + with testing.db.connect() as conn: + conn.execute( + arrtable.insert(), + intarr=[4, 5, 6] + ) + eq_( + conn.scalar( + select([arrtable.c.intarr]). + where(arrtable.c.intarr.overlap([7, 6])) + ), + [4, 5, 6] + ) + @testing.provide_metadata def test_tuple_flag(self): metadata = self.metadata -- 2.47.2