self.process(binary.right),
)
- def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ def visit_json_getitem_op_binary(
+ self, binary, operator, _cast_applied=False, **kw
+ ):
+ if (
+ not _cast_applied
+ and binary.type._type_affinity is not sqltypes.JSON
+ ):
+ kw["_cast_applied"] = True
+ return self.process(sql.cast(binary, binary.type), **kw)
+
if binary.type._type_affinity is sqltypes.JSON:
expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
else:
self.process(binary.right, **kw),
)
- def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ def visit_json_path_getitem_op_binary(
+ self, binary, operator, _cast_applied=False, **kw
+ ):
+ if (
+ not _cast_applied
+ and binary.type._type_affinity is not sqltypes.JSON
+ ):
+ kw["_cast_applied"] = True
+ return self.process(sql.cast(binary, binary.type), **kw)
+
if binary.type._type_affinity is sqltypes.JSON:
expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
else:
from ...sql import sqltypes
from ...sql.sqltypes import LargeBinary
from ...sql.sqltypes import PickleType
+from ...testing import Variation
class _LiteralRoundTripFixture:
("null", None),
)
+ @testing.combinations(
+ ("string",),
+ ("integer",),
+ ("float",),
+ ("numeric",),
+ ("boolean",),
+ argnames="cross_cast",
+ )
+ @testing.combinations(
+ ("boolean", True, {"string"}),
+ ("boolean", False, {"string"}),
+ ("boolean", None, {"all"}),
+ ("string", "45", {"integer", "float", "numeric"}),
+ ("string", "45.684", {"float", "numeric"}),
+ ("string", "some string", {"string"}),
+ ("string", None, {"all"}),
+ ("string", "réve illé", {"string"}),
+ ("string", "true", {"boolean"}),
+ ("string", "false", {"boolean"}),
+ ("integer", 15, {"string", "numeric", "float"}),
+ ("integer", 1, {"all"}),
+ ("integer", 0, {"all"}),
+ ("integer", None, {"all"}),
+ ("float", None, {"all"}),
+ ("float", 1234567.89, {"string", "numeric"}),
+ ("numeric", 1234567.89, {"string", "float"}),
+ argnames="datatype, value, allowed_targets",
+ )
+ @testing.variation("json_access", ["getitem", "path"])
+ def test_index_cross_casts(
+ self,
+ datatype,
+ value,
+ allowed_targets,
+ cross_cast,
+ json_access: Variation,
+ connection,
+ ):
+ """cross cast tests set up for #11074"""
+
+ data_table = self.tables.data_table
+ if json_access.getitem:
+ data_element = {"key1": value}
+ elif json_access.path:
+ data_element = {"attr1": {"key1": value}}
+ else:
+ json_access.fail()
+
+ datatype, _, _ = self._json_value_insert(
+ connection, datatype, value, data_element
+ )
+
+ if json_access.getitem:
+ expr = data_table.c.data["key1"]
+ elif json_access.path:
+ expr = data_table.c.data[("attr1", "key1")]
+ else:
+ json_access.fail()
+
+ if cross_cast == "numeric":
+ expr = getattr(expr, "as_%s" % cross_cast)(10, 2)
+ else:
+ expr = getattr(expr, "as_%s" % cross_cast)()
+
+ if (
+ cross_cast != datatype
+ and "all" not in allowed_targets
+ and cross_cast not in allowed_targets
+ ):
+ try:
+ roundtrip = connection.scalar(select(expr))
+ except Exception:
+ # We can't predict in a backend-agnostic way what CASTS
+ # will fail and which will proceed with a (possibly
+ # useless) value. PostgreSQL CASTS fail in 100% of cases
+ # that the types aren't compatible. SQL Server fails in
+ # most, except for booleans because it uses ints for
+ # booleans which are easier to cast. MySQL and SQLite do
+ # not raise for CAST under any circumstances for the four
+ # of string/int/float/boolean. one way to force a fail
+ # would be to have backends inject a special version of
+ # Float/Unicode/Integer/Boolean that enforces a python
+ # check of the expected data value. However for now we let
+ # the backends ensure the expected type is returned but we
+ # don't try to validate the value itself for non-sensical
+ # casts.
+ return
+ else:
+ roundtrip = connection.scalar(select(expr))
+
+ if value is None:
+ eq_(roundtrip, None)
+ elif cross_cast == "string":
+ assert isinstance(roundtrip, str)
+ elif cross_cast == "integer":
+ assert isinstance(roundtrip, int)
+ elif cross_cast == "float":
+ assert isinstance(roundtrip, float)
+ elif cross_cast == "numeric":
+ assert isinstance(roundtrip, decimal.Decimal)
+ elif cross_cast == "boolean":
+ assert isinstance(roundtrip, bool)
+ else:
+ assert False
+
class JSONLegacyStringCastIndexTest(
_LiteralRoundTripFixture, fixtures.TablesTest