]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add "eager_parenthesis" late-compilation rule, use w/ PG JSON/HSTORE
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Sep 2016 14:09:56 +0000 (10:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Oct 2016 13:46:11 +0000 (09:46 -0400)
Added compiler-level flags used by Postgresql to place additional
parenthesis than would normally be generated by precedence rules
around operations involving JSON, HSTORE indexing operators as well as
within their operands since it has been observed that Postgresql's
precedence rules for at least the HSTORE indexing operator is not
consistent between 9.4 and 9.5.

Fixes: #3806
Change-Id: I5899677b330595264543b055abd54f3c76bfabf2

doc/build/changelog/changelog_11.rst
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/hstore.py
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/operators.py
test/dialect/postgresql/test_types.py
test/sql/test_operators.py

index af217170795967c5bd665ab5fd7bba830b7fe17d..c2dd2d84837dd919e14efe9cabc656d2346cf213 100644 (file)
         those of the "excluded" namespace would not be table-qualified
         in the WHERE clauses in the statement.
 
+     .. change::
+        :tags: bug, sql, postgresql
+        :tickets: 3806
+
+        Added compiler-level flags used by Postgresql to place additional
+        parenthesis than would normally be generated by precedence rules
+        around operations involving JSON, HSTORE indexing operators as well as
+        within their operands since it has been observed that Postgresql's
+        precedence rules for at least the HSTORE indexing operator is not
+        consistent between 9.4 and 9.5.
+
     .. change::
         :tags: bug, sql, mysql
         :tickets: 3803
index a9f11aae01c6f4cd92917a1186bbb87540f290d4..bde855fbe651c7524b53e574a38e0c25b6c39697 100644 (file)
@@ -1270,11 +1270,13 @@ class PGCompiler(compiler.SQLCompiler):
         )
 
     def visit_json_getitem_op_binary(self, binary, operator, **kw):
+        kw['eager_grouping'] = True
         return self._generate_generic_binary(
             binary, " -> ", **kw
         )
 
     def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+        kw['eager_grouping'] = True
         return self._generate_generic_binary(
             binary, " #> ", **kw
         )
index 67923fe39ecd604e1f121122a07e89fbcbcd25df..d3ff30efbecac13bdd043bb0a0a51d178c6747be 100644 (file)
@@ -16,29 +16,36 @@ from ... import util
 
 __all__ = ('HSTORE', 'hstore')
 
+idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
 
 GETITEM = operators.custom_op(
-    "->", precedence=15, natural_self_precedent=True,
+    "->", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 HAS_KEY = operators.custom_op(
-    "?", precedence=15, natural_self_precedent=True
+    "?", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 HAS_ALL = operators.custom_op(
-    "?&", precedence=15, natural_self_precedent=True
+    "?&", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 HAS_ANY = operators.custom_op(
-    "?|", precedence=15, natural_self_precedent=True
+    "?|", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 CONTAINS = operators.custom_op(
-    "@>", precedence=15, natural_self_precedent=True
+    "@>", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 CONTAINED_BY = operators.custom_op(
-    "<@", precedence=15, natural_self_precedent=True
+    "<@", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 
index 05c4d014d3a8a75f1584fd2cf27afc911e519c84..821018471ce1b5fc9f39323412c021185b8618ae 100644 (file)
@@ -17,33 +17,42 @@ from ... import util
 
 __all__ = ('JSON', 'JSONB')
 
+idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
+
 ASTEXT = operators.custom_op(
-    "->>", precedence=15, natural_self_precedent=True,
+    "->>", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 JSONPATH_ASTEXT = operators.custom_op(
-    "#>>", precedence=15, natural_self_precedent=True,
+    "#>>", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 
 HAS_KEY = operators.custom_op(
-    "?", precedence=15, natural_self_precedent=True
+    "?", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 HAS_ALL = operators.custom_op(
-    "?&", precedence=15, natural_self_precedent=True
+    "?&", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 HAS_ANY = operators.custom_op(
-    "?|", precedence=15, natural_self_precedent=True
+    "?|", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 CONTAINS = operators.custom_op(
-    "@>", precedence=15, natural_self_precedent=True
+    "@>", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 CONTAINED_BY = operators.custom_op(
-    "<@", precedence=15, natural_self_precedent=True
+    "<@", precedence=idx_precedence, natural_self_precedent=True,
+    eager_grouping=True
 )
 
 
index a2dbcee5c3a207fbe8ea024071d6ace1ad4ff243..6527eb8c653bdbb465f252995c13a7f853b5caa4 100644 (file)
@@ -994,7 +994,9 @@ class SQLCompiler(Compiled):
         return "NOT %s" % self.visit_binary(
             binary, override_operator=operators.match_op)
 
-    def visit_binary(self, binary, override_operator=None, **kw):
+    def visit_binary(self, binary, override_operator=None,
+                     eager_grouping=False, **kw):
+
         # don't allow "? = ?" to render
         if self.ansi_bind_rules and \
                 isinstance(binary.left, elements.BindParameter) and \
@@ -1014,6 +1016,7 @@ class SQLCompiler(Compiled):
                 return self._generate_generic_binary(binary, opstring, **kw)
 
     def visit_custom_op_binary(self, element, operator, **kw):
+        kw['eager_grouping'] = operator.eager_grouping
         return self._generate_generic_binary(
             element, " " + operator.opstring + " ", **kw)
 
@@ -1025,10 +1028,21 @@ class SQLCompiler(Compiled):
         return self._generate_generic_unary_modifier(
             element, " " + operator.opstring, **kw)
 
-    def _generate_generic_binary(self, binary, opstring, **kw):
-        return binary.left._compiler_dispatch(self, **kw) + \
+    def _generate_generic_binary(
+            self, binary, opstring, eager_grouping=False, **kw):
+
+        _in_binary = kw.get('_in_binary', False)
+
+        kw['_in_binary'] = True
+        text = binary.left._compiler_dispatch(
+            self, eager_grouping=eager_grouping, **kw) + \
             opstring + \
-            binary.right._compiler_dispatch(self, **kw)
+            binary.right._compiler_dispatch(
+                self, eager_grouping=eager_grouping, **kw)
+
+        if _in_binary and eager_grouping:
+            text = "(%s)" % text
+        return text
 
     def _generate_generic_unary_operator(self, unary, opstring, **kw):
         return opstring + unary.element._compiler_dispatch(self, **kw)
@@ -2215,6 +2229,12 @@ class StrSQLCompiler(SQLCompiler):
             self.process(binary.right, **kw)
         )
 
+    def visit_json_getitem_op_binary(self, binary, operator, **kw):
+        return self.visit_getitem_binary(binary, operator, **kw)
+
+    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+        return self.visit_getitem_binary(binary, operator, **kw)
+
     def returning_clause(self, stmt, returning_cols):
         columns = [
             self._label_select_column(None, c, True, False, {})
index 14260668045c1f46074ed38d16c321cd07225b1c..69eee28abe3c3738e7c345f971ae1d2e441b041d 100644 (file)
@@ -215,11 +215,12 @@ class custom_op(object):
 
     def __init__(
             self, opstring, precedence=0, is_comparison=False,
-            natural_self_precedent=False):
+            natural_self_precedent=False, eager_grouping=False):
         self.opstring = opstring
         self.precedence = precedence
         self.is_comparison = is_comparison
         self.natural_self_precedent = natural_self_precedent
+        self.eager_grouping = eager_grouping
 
     def __eq__(self, other):
         return isinstance(other, custom_op) and \
@@ -935,9 +936,10 @@ _PRECEDENCE = {
     from_: 15,
     any_op: 15,
     all_op: 15,
+    getitem: 15,
     json_getitem_op: 15,
     json_path_getitem_op: 15,
-    getitem: 15,
+
     mul: 8,
     truediv: 8,
     div: 8,
@@ -985,6 +987,7 @@ _PRECEDENCE = {
 
     as_: -1,
     exists: 0,
+
     _asbool: -10,
     _smallest: _smallest,
     _largest: _largest
index 6bcc4cf9ad40e041cd8835b5ceb82d5c5b5906e9..b611c4a9dfdd7f1465d4482dde0aaf854af749ac 100644 (file)
@@ -1816,7 +1816,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
     def test_where_getitem(self):
         self._test_where(
             self.hashcol['bar'] == None,
-            "test_table.hash -> %(hash_1)s IS NULL"
+            "(test_table.hash -> %(hash_1)s) IS NULL"
         )
 
     def test_cols_get(self):
@@ -1902,6 +1902,12 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
             "(test_table.hash || test_table.hash) -> %(param_1)s AS anon_1"
         )
 
+    def test_cols_against_is(self):
+        self._test_cols(
+            self.hashcol['foo'] != None,
+            "(test_table.hash -> %(hash_1)s) IS NOT NULL AS anon_1"
+        )
+
     def test_cols_keys(self):
         self._test_cols(
             # hide from 2to3
@@ -2436,13 +2442,13 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
     def test_where_getitem(self):
         self._test_where(
             self.jsoncol['bar'] == None,
-            "test_table.test_column -> %(test_column_1)s IS NULL"
+            "(test_table.test_column -> %(test_column_1)s) IS NULL"
         )
 
     def test_where_path(self):
         self._test_where(
             self.jsoncol[("foo", 1)] == None,
-            "test_table.test_column #> %(test_column_1)s IS NULL"
+            "(test_table.test_column #> %(test_column_1)s) IS NULL"
         )
 
     def test_path_typing(self):
@@ -2481,27 +2487,27 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
     def test_where_getitem_as_text(self):
         self._test_where(
             self.jsoncol['bar'].astext == None,
-            "test_table.test_column ->> %(test_column_1)s IS NULL"
+            "(test_table.test_column ->> %(test_column_1)s) IS NULL"
         )
 
     def test_where_getitem_astext_cast(self):
         self._test_where(
             self.jsoncol['bar'].astext.cast(Integer) == 5,
-            "CAST(test_table.test_column ->> %(test_column_1)s AS INTEGER) "
+            "CAST((test_table.test_column ->> %(test_column_1)s) AS INTEGER) "
             "= %(param_1)s"
         )
 
     def test_where_getitem_json_cast(self):
         self._test_where(
             self.jsoncol['bar'].cast(Integer) == 5,
-            "CAST(test_table.test_column -> %(test_column_1)s AS INTEGER) "
+            "CAST((test_table.test_column -> %(test_column_1)s) AS INTEGER) "
             "= %(param_1)s"
         )
 
     def test_where_path_as_text(self):
         self._test_where(
             self.jsoncol[("foo", 1)].astext == None,
-            "test_table.test_column #>> %(test_column_1)s IS NULL"
+            "(test_table.test_column #>> %(test_column_1)s) IS NULL"
         )
 
     def test_cols_get(self):
@@ -2781,6 +2787,13 @@ class JSONRoundTripTest(fixtures.TablesTest):
         ).first()
         eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
 
+        result = engine.execute(
+            select([data_table.c.data]).where(
+                data_table.c.data['k1'].astext.cast(String) == 'r3v1'
+            )
+        ).first()
+        eq_(result, ({'k1': 'r3v1', 'k2': 'r3v2'},))
+
     def _test_fixed_round_trip(self, engine):
         s = select([
             cast(
index 99f8a10ca3c605a119284174edea830fe63dc450..bbc912ffd770c715678e75b7eae9ab199d17256c 100644 (file)
@@ -670,13 +670,13 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
             def visit_json_getitem_op_binary(self, binary, operator, **kw):
                 return self._generate_generic_binary(
-                    binary, " -> ", **kw
+                    binary, " -> ", eager_grouping=True, **kw
                 )
 
             def visit_json_path_getitem_op_binary(
                     self, binary, operator, **kw):
                 return self._generate_generic_binary(
-                    binary, " #> ", **kw
+                    binary, " #> ", eager_grouping=True, **kw
                 )
 
             def visit_getitem_binary(self, binary, operator, **kw):
@@ -748,12 +748,37 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={}
         )
 
+    def test_getindex_sqlexpr_right_grouping(self):
+
+        col = Column('x', self.MyType())
+        col2 = Column('y', Integer())
+
         self.assert_compile(
             col[col2 + 8],
             "x -> (y + :y_1)",
             checkparams={'y_1': 8}
         )
 
+    def test_getindex_sqlexpr_left_grouping(self):
+
+        col = Column('x', self.MyType())
+
+        self.assert_compile(
+            col[8] != None,
+            "(x -> :x_1) IS NOT NULL"
+        )
+
+    def test_getindex_sqlexpr_both_grouping(self):
+
+        col = Column('x', self.MyType())
+        col2 = Column('y', Integer())
+
+        self.assert_compile(
+            col[col2 + 8] != None,
+            "(x -> (y + :y_1)) IS NOT NULL",
+            checkparams={'y_1': 8}
+        )
+
     def test_override_operators(self):
         special_index_op = operators.custom_op('$$>')