]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure datatype roundtrips for JSON dialects
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Aug 2025 22:12:13 +0000 (18:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Aug 2025 14:09:34 +0000 (10:09 -0400)
Improved the behavior of JSON accessors :meth:`.JSON.Comparator.as_string`,
:meth:`.JSON.Comparator.as_boolean`, :meth:`.JSON.Comparator.as_float`,
:meth:`.JSON.Comparator.as_integer` to use CAST in a similar way that
the PostgreSQL, MySQL and SQL Server dialects do to help enforce the
expected Python type is returned.

The :meth:`.JSON.Comparator.as_boolean` method when used on a JSON value on
SQL Server will now force a cast to occur for values that are not simple
`true`/`false` JSON literals, forcing SQL Server to attempt to interpret
the given value as a 1/0 BIT, or raise an error if not possible. Previously
the expression would return NULL.

Fixes: #11074
Change-Id: I5024b78ec2fa6b61a9c6ee176112f1b761eeab98

doc/build/changelog/unreleased_21/11074.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/suite/test_types.py

diff --git a/doc/build/changelog/unreleased_21/11074.rst b/doc/build/changelog/unreleased_21/11074.rst
new file mode 100644 (file)
index 0000000..c5741e3
--- /dev/null
@@ -0,0 +1,23 @@
+.. change::
+    :tags: bug, sqlite
+    :tickets: 11074
+
+    Improved the behavior of JSON accessors :meth:`.JSON.Comparator.as_string`,
+    :meth:`.JSON.Comparator.as_boolean`, :meth:`.JSON.Comparator.as_float`,
+    :meth:`.JSON.Comparator.as_integer` to use CAST in a similar way that
+    the PostgreSQL, MySQL and SQL Server dialects do to help enforce the
+    expected Python type is returned.
+
+
+
+.. change::
+    :tags: bug, mssql
+    :tickets: 11074
+
+    The :meth:`.JSON.Comparator.as_boolean` method when used on a JSON value on
+    SQL Server will now force a cast to occur for values that are not simple
+    `true`/`false` JSON literals, forcing SQL Server to attempt to interpret
+    the given value as a 1/0 BIT, or raise an error if not possible. Previously
+    the expression would return NULL.
+
+
index c0bf43304af4fa94ff9cb9fded0a67d2d9f293a3..88fea92d8fa7a73fb46a593078b28c65d8f99acf 100644 (file)
@@ -2479,7 +2479,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
             # the NULL handling is particularly weird with boolean, so
             # explicitly return numeric (BIT) constants
             type_expression = (
-                "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE NULL"
+                "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE "
+                "CAST(JSON_VALUE(%s, %s) AS BIT)"
+                % (
+                    self.process(binary.left, **kw),
+                    self.process(binary.right, **kw),
+                )
             )
         elif binary.type._type_affinity is sqltypes.String:
             # TODO: does this comment (from mysql) apply to here, too?
index b8bb052a99ffc1b99e689fdf8824fb45c3765528..d1abf26c3c5c31aa2e192deb56802f826a14163c 100644 (file)
@@ -1521,7 +1521,16 @@ class SQLiteCompiler(compiler.SQLCompiler):
             self.process(binary.right),
         )
 
-    def visit_json_getitem_op_binary(self, binary, operator, **kw):
+    def visit_json_getitem_op_binary(
+        self, binary, operator, _cast_applied=False, **kw
+    ):
+        if (
+            not _cast_applied
+            and binary.type._type_affinity is not sqltypes.JSON
+        ):
+            kw["_cast_applied"] = True
+            return self.process(sql.cast(binary, binary.type), **kw)
+
         if binary.type._type_affinity is sqltypes.JSON:
             expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
         else:
@@ -1532,7 +1541,16 @@ class SQLiteCompiler(compiler.SQLCompiler):
             self.process(binary.right, **kw),
         )
 
-    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+    def visit_json_path_getitem_op_binary(
+        self, binary, operator, _cast_applied=False, **kw
+    ):
+        if (
+            not _cast_applied
+            and binary.type._type_affinity is not sqltypes.JSON
+        ):
+            kw["_cast_applied"] = True
+            return self.process(sql.cast(binary, binary.type), **kw)
+
         if binary.type._type_affinity is sqltypes.JSON:
             expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
         else:
index 5f1bf75d504836b7b8b9abe3e401117788143444..112a5d0df13e6471c99ebc107c9c2dba84e9658f 100644 (file)
@@ -59,6 +59,7 @@ from ...orm import Session
 from ...sql import sqltypes
 from ...sql.sqltypes import LargeBinary
 from ...sql.sqltypes import PickleType
+from ...testing import Variation
 
 
 class _LiteralRoundTripFixture:
@@ -1782,6 +1783,111 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
                 ("null", None),
             )
 
+    @testing.combinations(
+        ("string",),
+        ("integer",),
+        ("float",),
+        ("numeric",),
+        ("boolean",),
+        argnames="cross_cast",
+    )
+    @testing.combinations(
+        ("boolean", True, {"string"}),
+        ("boolean", False, {"string"}),
+        ("boolean", None, {"all"}),
+        ("string", "45", {"integer", "float", "numeric"}),
+        ("string", "45.684", {"float", "numeric"}),
+        ("string", "some string", {"string"}),
+        ("string", None, {"all"}),
+        ("string", "réve illé", {"string"}),
+        ("string", "true", {"boolean"}),
+        ("string", "false", {"boolean"}),
+        ("integer", 15, {"string", "numeric", "float"}),
+        ("integer", 1, {"all"}),
+        ("integer", 0, {"all"}),
+        ("integer", None, {"all"}),
+        ("float", None, {"all"}),
+        ("float", 1234567.89, {"string", "numeric"}),
+        ("numeric", 1234567.89, {"string", "float"}),
+        argnames="datatype, value, allowed_targets",
+    )
+    @testing.variation("json_access", ["getitem", "path"])
+    def test_index_cross_casts(
+        self,
+        datatype,
+        value,
+        allowed_targets,
+        cross_cast,
+        json_access: Variation,
+        connection,
+    ):
+        """cross cast tests set up for #11074"""
+
+        data_table = self.tables.data_table
+        if json_access.getitem:
+            data_element = {"key1": value}
+        elif json_access.path:
+            data_element = {"attr1": {"key1": value}}
+        else:
+            json_access.fail()
+
+        datatype, _, _ = self._json_value_insert(
+            connection, datatype, value, data_element
+        )
+
+        if json_access.getitem:
+            expr = data_table.c.data["key1"]
+        elif json_access.path:
+            expr = data_table.c.data[("attr1", "key1")]
+        else:
+            json_access.fail()
+
+        if cross_cast == "numeric":
+            expr = getattr(expr, "as_%s" % cross_cast)(10, 2)
+        else:
+            expr = getattr(expr, "as_%s" % cross_cast)()
+
+        if (
+            cross_cast != datatype
+            and "all" not in allowed_targets
+            and cross_cast not in allowed_targets
+        ):
+            try:
+                roundtrip = connection.scalar(select(expr))
+            except Exception:
+                # We can't predict in a backend-agnostic way what CASTS
+                # will fail and which will proceed with a (possibly
+                # useless) value.  PostgreSQL CASTS fail in 100% of cases
+                # that the types aren't compatible. SQL Server fails in
+                # most, except for booleans because it uses ints for
+                # booleans which are easier to cast. MySQL and SQLite do
+                # not raise for CAST under any circumstances for the four
+                # of string/int/float/boolean. one way to force a fail
+                # would be to have backends inject a special version of
+                # Float/Unicode/Integer/Boolean that enforces a python
+                # check of the expected data value. However for now we let
+                # the backends ensure the expected type is returned but we
+                # don't try to validate the value itself for non-sensical
+                # casts.
+                return
+        else:
+            roundtrip = connection.scalar(select(expr))
+
+        if value is None:
+            eq_(roundtrip, None)
+        elif cross_cast == "string":
+            assert isinstance(roundtrip, str)
+        elif cross_cast == "integer":
+            assert isinstance(roundtrip, int)
+        elif cross_cast == "float":
+            assert isinstance(roundtrip, float)
+        elif cross_cast == "numeric":
+            assert isinstance(roundtrip, decimal.Decimal)
+        elif cross_cast == "boolean":
+            assert isinstance(roundtrip, bool)
+        else:
+            assert False
+
 
 class JSONLegacyStringCastIndexTest(
     _LiteralRoundTripFixture, fixtures.TablesTest