]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Build string/int processors for JSONIndexType, JSONPathType
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Aug 2016 15:56:31 +0000 (11:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Aug 2016 16:38:58 +0000 (12:38 -0400)
Fixed regression in JSON datatypes where the "literal processor" for
a JSON index value, that needs to take effect for example within DDL,
would not be invoked for the value.  The native String and Integer
datatypes are now called upon from within the JSONIndexType
and JSONPathType.  This is applied to the generic, Postgresql, and
MySQL JSON types.

Change-Id: Ifa5f2acfeee57a79d01d7fc85d265a37bd27c716
Fixes: #3765
doc/build/changelog/changelog_11.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/json.py
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/suite/test_types.py
test/sql/test_types.py

index fe6be3b17e7f1886aa7bbcd888bc6c8fd72f4a8f..b94104be83097c39f81aeb1be08f87b446ff07f1 100644 (file)
         Fixed bug where the "literal_binds" flag would not be propagated
         to a CAST expression under MySQL.
 
+    .. change::
+        :tags: bug, sql, postgresql, mysql
+        :tickets: 3765
+
+        Fixed regression in JSON datatypes where the "literal processor" for
+        a JSON index value would not be invoked.  The native String and Integer
+        datatypes are now called upon from within the JSONIndexType
+        and JSONPathType.  This is applied to the generic, Postgresql, and
+        MySQL JSON types and also has a dependency on :ticket:`3766`.
+
     .. change::
         :tags: change, orm
 
index 7ab9fad69fd5bf5593ecb67ae17ce18f9c6d5468..e7e5338905e77fa8758f6aa53a6b3bc47ecb136a 100644 (file)
@@ -763,13 +763,13 @@ class MySQLCompiler(compiler.SQLCompiler):
 
     def visit_json_getitem_op_binary(self, binary, operator, **kw):
         return "JSON_EXTRACT(%s, %s)" % (
-            self.process(binary.left),
-            self.process(binary.right))
+            self.process(binary.left, **kw),
+            self.process(binary.right, **kw))
 
     def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
         return "JSON_EXTRACT(%s, %s)" % (
-            self.process(binary.left),
-            self.process(binary.right))
+            self.process(binary.left, **kw),
+            self.process(binary.right, **kw))
 
     def visit_concat_op_binary(self, binary, operator, **kw):
         return "concat(%s, %s)" % (self.process(binary.left),
index 3840a7cd67c5f07596d393ea54dc6356fd6430f8..8dd99bd45123a5edb1b46a075fddd62f91928edf 100644 (file)
@@ -31,25 +31,49 @@ class JSON(sqltypes.JSON):
 
     pass
 
-class JSONIndexType(sqltypes.JSON.JSONIndexType):
+
+class _FormatTypeMixin(object):
+    def _format_value(self, value):
+        raise NotImplementedError()
+
     def bind_processor(self, dialect):
+        super_proc = self.string_bind_processor(dialect)
+
         def process(value):
-            if isinstance(value, int):
-                return "$[%s]" % value
-            else:
-                return '$."%s"' % value
+            value = self._format_value(value)
+            if super_proc:
+                value = super_proc(value)
+            return value
 
         return process
 
+    def literal_processor(self, dialect):
+        super_proc = self.string_literal_processor(dialect)
 
-class JSONPathType(sqltypes.JSON.JSONPathType):
-    def bind_processor(self, dialect):
         def process(value):
-            return "$%s" % (
-                "".join([
-                    "[%s]" % elem if isinstance(elem, int)
-                    else '."%s"' % elem for elem in value
-                ])
-            )
+            value = self._format_value(value)
+            if super_proc:
+                value = super_proc(value)
+            return value
 
         return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+
+    def _format_value(self, value):
+        if isinstance(value, int):
+            value = "$[%s]" % value
+        else:
+            value = '$."%s"' % value
+        return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+    def _format_value(self, value):
+        return "$%s" % (
+            "".join([
+                "[%s]" % elem if isinstance(elem, int)
+                else '."%s"' % elem for elem in value
+            ])
+        )
index b0f0f7cf0c41226db0378da46f7da294a872f9f0..05c4d014d3a8a75f1584fd2cf27afc911e519c84 100644 (file)
@@ -49,10 +49,28 @@ CONTAINED_BY = operators.custom_op(
 
 class JSONPathType(sqltypes.JSON.JSONPathType):
     def bind_processor(self, dialect):
+        super_proc = self.string_bind_processor(dialect)
+
+        def process(value):
+            assert isinstance(value, collections.Sequence)
+            tokens = [util.text_type(elem)for elem in value]
+            value = "{%s}" % (", ".join(tokens))
+            if super_proc:
+                value = super_proc(value)
+            return value
+
+        return process
+
+    def literal_processor(self, dialect):
+        super_proc = self.string_literal_processor(dialect)
+
         def process(value):
             assert isinstance(value, collections.Sequence)
-            tokens = [util.text_type(elem) for elem in value]
-            return "{%s}" % (", ".join(tokens))
+            tokens = [util.text_type(elem)for elem in value]
+            value = "{%s}" % (", ".join(tokens))
+            if super_proc:
+                value = super_proc(value)
+            return value
 
         return process
 
index 9772313365135f5ca3a116431acd353be7a8d681..b55d435ad02ba74691c997dd41b06152a5629a8e 100644 (file)
@@ -1789,7 +1789,45 @@ class JSON(Indexable, TypeEngine):
          """
         self.none_as_null = none_as_null
 
-    class JSONIndexType(TypeEngine):
+    class JSONElementType(TypeEngine):
+        """common function for index / path elements in a JSON expression."""
+
+        _integer = Integer()
+        _string = String()
+
+        def string_bind_processor(self, dialect):
+            return self._string._cached_bind_processor(dialect)
+
+        def string_literal_processor(self, dialect):
+            return self._string._cached_literal_processor(dialect)
+
+        def bind_processor(self, dialect):
+            int_processor = self._integer._cached_bind_processor(dialect)
+            string_processor = self.string_bind_processor(dialect)
+
+            def process(value):
+                if int_processor and isinstance(value, int):
+                    value = int_processor(value)
+                elif string_processor and isinstance(value, util.string_types):
+                    value = string_processor(value)
+                return value
+
+            return process
+
+        def literal_processor(self, dialect):
+            int_processor = self._integer._cached_literal_processor(dialect)
+            string_processor = self.string_literal_processor(dialect)
+
+            def process(value):
+                if int_processor and isinstance(value, int):
+                    value = int_processor(value)
+                elif string_processor and isinstance(value, util.string_types):
+                    value = string_processor(value)
+                return value
+
+            return process
+
+    class JSONIndexType(JSONElementType):
         """Placeholder for the datatype of a JSON index value.
 
         This allows execution-time processing of JSON index values
@@ -1797,7 +1835,7 @@ class JSON(Indexable, TypeEngine):
 
         """
 
-    class JSONPathType(TypeEngine):
+    class JSONPathType(JSONElementType):
         """Placeholder type for JSON path operations.
 
         This allows execution-time processing of a path-based
index d74ef60da70731b605aa5bd969384c810a417a90..d85531396954099ff6345cca9e3f17ffd875c077 100644 (file)
@@ -736,14 +736,18 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def _test_index_criteria(self, crit, expected):
         self._criteria_fixture()
         with config.db.connect() as conn:
+            stmt = select([self.tables.data_table.c.name]).where(crit)
+
             eq_(
-                conn.scalar(
-                    select([self.tables.data_table.c.name]).
-                    where(crit)
-                ),
+                conn.scalar(stmt),
                 expected
             )
 
+            literal_sql = str(stmt.compile(
+                config.db, compile_kwargs={"literal_binds": True}))
+
+            eq_(conn.scalar(literal_sql), expected)
+
     def test_crit_spaces_in_key(self):
         name = self.tables.data_table.c.name
         col = self.tables.data_table.c['data']
index 49a1d8f15ffc0183aedf09a618f0f1b8e17a95d9..3374a67213163831adafe96228ae12f2b1ebcb9e 100644 (file)
@@ -1630,6 +1630,77 @@ class JSONTest(fixtures.TestBase):
             None
         )
 
+    def _dialect_index_fixture(self, int_processor, str_processor):
+        class MyInt(Integer):
+            def bind_processor(self, dialect):
+                return lambda value: value + 10
+
+            def literal_processor(self, diaect):
+                return lambda value: str(value + 15)
+
+        class MyString(String):
+            def bind_processor(self, dialect):
+                return lambda value: value + "10"
+
+            def literal_processor(self, diaect):
+                return lambda value: value + "15"
+
+        class MyDialect(default.DefaultDialect):
+            colspecs = {}
+            if int_processor:
+                colspecs[Integer] = MyInt
+            if str_processor:
+                colspecs[String] = MyString
+
+        return MyDialect()
+
+    def test_index_bind_proc_int(self):
+        expr = self.test_table.c.test_column[5]
+
+        int_dialect = self._dialect_index_fixture(True, True)
+        non_int_dialect = self._dialect_index_fixture(False, True)
+
+        bindproc = expr.right.type._cached_bind_processor(int_dialect)
+        eq_(bindproc(expr.right.value), 15)
+
+        bindproc = expr.right.type._cached_bind_processor(non_int_dialect)
+        eq_(bindproc(expr.right.value), 5)
+
+    def test_index_literal_proc_int(self):
+        expr = self.test_table.c.test_column[5]
+
+        int_dialect = self._dialect_index_fixture(True, True)
+        non_int_dialect = self._dialect_index_fixture(False, True)
+
+        bindproc = expr.right.type._cached_literal_processor(int_dialect)
+        eq_(bindproc(expr.right.value), "20")
+
+        bindproc = expr.right.type._cached_literal_processor(non_int_dialect)
+        eq_(bindproc(expr.right.value), "5")
+
+    def test_index_bind_proc_str(self):
+        expr = self.test_table.c.test_column['five']
+
+        str_dialect = self._dialect_index_fixture(True, True)
+        non_str_dialect = self._dialect_index_fixture(False, False)
+
+        bindproc = expr.right.type._cached_bind_processor(str_dialect)
+        eq_(bindproc(expr.right.value), 'five10')
+
+        bindproc = expr.right.type._cached_bind_processor(non_str_dialect)
+        eq_(bindproc(expr.right.value), 'five')
+
+    def test_index_literal_proc_str(self):
+        expr = self.test_table.c.test_column['five']
+
+        str_dialect = self._dialect_index_fixture(True, True)
+        non_str_dialect = self._dialect_index_fixture(False, False)
+
+        bindproc = expr.right.type._cached_literal_processor(str_dialect)
+        eq_(bindproc(expr.right.value), "five15")
+
+        bindproc = expr.right.type._cached_literal_processor(non_str_dialect)
+        eq_(bindproc(expr.right.value), "'five'")
 
 class ArrayTest(fixtures.TestBase):