]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
limit None->null coercion to not occur with crud
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Jul 2021 15:25:31 +0000 (11:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Jul 2021 19:09:28 +0000 (15:09 -0400)
Fixed issue where type-specific bound parameter handlers would not be
called upon in the case of using the :meth:`_sql.Insert.values` method with
the Python ``None`` value; in particular, this would be noticed when using
the :class:`_types.JSON` datatype as well as related PostgreSQL specific
types such as :class:`_postgresql.JSONB` which would fail to encode the
Python ``None`` value into JSON null, however the issue was generalized to
any bound parameter handler in conjunction with this specific method of
:class:`_sql.Insert`.

The issue with coercions forcing out ``null()`` may still impact
SQL expression usage as well; the change here is limited to crud
as the behavior there is relevant to some use cases, which may
need to be evaluated separately.

Fixes: #6770
Change-Id: If53edad811b37dada7578a89daf395628db058a6

doc/build/changelog/unreleased_14/6770.rst [new file with mode: 0644]
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/suite/test_types.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/6770.rst b/doc/build/changelog/unreleased_14/6770.rst
new file mode 100644 (file)
index 0000000..7da35f4
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6770
+
+    Fixed issue where type-specific bound parameter handlers would not be
+    called upon in the case of using the :meth:`_sql.Insert.values` method with
+    the Python ``None`` value; in particular, this would be noticed when using
+    the :class:`_types.JSON` datatype as well as related PostgreSQL specific
+    types such as :class:`_postgresql.JSONB` which would fail to encode the
+    Python ``None`` value into JSON null, however the issue was generalized to
+    any bound parameter handler in conjunction with this specific method of
+    :class:`_sql.Insert`.
+
index 16a68c8ffd06f38861ed6a4da517ba67dca9b87b..e21f4a9a5fcf895688da878ff463864490177d8c 100644 (file)
@@ -464,7 +464,14 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
     def _literal_coercion(
         self, element, name=None, type_=None, argname=None, is_crud=False, **kw
     ):
-        if element is None:
+        if (
+            element is None
+            and not is_crud
+            and (type_ is None or not type_.should_evaluate_none)
+        ):
+            # TODO: there's no test coverage now for the
+            # "should_evaluate_none" part of this, as outside of "crud" this
+            # codepath is not normally used except in some special cases
             return elements.Null()
         else:
             try:
index cc5f9a6b7dad0b94b6e57607a0bec2aeb8e59962..1b05465c99742520637ad76a8dafa1bf7a3ba18d 100644 (file)
@@ -2371,6 +2371,14 @@ class JSON(Indexable, TypeEngine):
               :paramref:`_schema.Column.server_default`; a value of ``None``
               passed for these parameters means "no default present".
 
+              Additionally, when used in SQL comparison expressions, the
+              Python value ``None`` continues to refer to SQL null, and not
+              JSON NULL.  The :paramref:`_types.JSON.none_as_null` flag refers
+              explicitly to the **persistence** of the value within an
+              INSERT or UPDATE statement.   The :attr:`_types.JSON.NULL`
+              value should be used for SQL expressions that wish to compare to
+              JSON null.
+
          .. seealso::
 
               :attr:`.types.JSON.NULL`
index 3e54cc2e45a7f8ed7b99a708ebc1d6a584a14611..f793ff5290d6993cab6594e5ae762dfbf383e0f4 100644 (file)
@@ -817,7 +817,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             metadata,
             Column("id", Integer, primary_key=True),
             Column("name", String(30), nullable=False),
-            Column("data", cls.datatype),
+            Column("data", cls.datatype, nullable=False),
             Column("nulldata", cls.datatype(none_as_null=True)),
         )
 
@@ -1101,13 +1101,47 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             eq_(js.mock_calls, [mock.call(data_element)])
             eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
 
-    def test_round_trip_none_as_sql_null(self, connection):
+    @testing.combinations(
+        ("parameters",),
+        ("multiparameters",),
+        ("values",),
+        ("omit",),
+        argnames="insert_type",
+    )
+    def test_round_trip_none_as_sql_null(self, connection, insert_type):
         col = self.tables.data_table.c["nulldata"]
 
         conn = connection
-        conn.execute(
-            self.tables.data_table.insert(), {"name": "r1", "data": None}
-        )
+
+        if insert_type == "parameters":
+            stmt, params = self.tables.data_table.insert(), {
+                "name": "r1",
+                "nulldata": None,
+                "data": None,
+            }
+        elif insert_type == "multiparameters":
+            stmt, params = self.tables.data_table.insert(), [
+                {"name": "r1", "nulldata": None, "data": None}
+            ]
+        elif insert_type == "values":
+            stmt, params = (
+                self.tables.data_table.insert().values(
+                    name="r1",
+                    nulldata=None,
+                    data=None,
+                ),
+                {},
+            )
+        elif insert_type == "omit":
+            stmt, params = (
+                self.tables.data_table.insert(),
+                {"name": "r1", "data": None},
+            )
+
+        else:
+            assert False
+
+        conn.execute(stmt, params)
 
         eq_(
             conn.scalar(
@@ -1138,24 +1172,45 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
         eq_(conn.scalar(select(col)), None)
 
-    def test_round_trip_none_as_json_null(self):
+    @testing.combinations(
+        ("parameters",),
+        ("multiparameters",),
+        ("values",),
+        argnames="insert_type",
+    )
+    def test_round_trip_none_as_json_null(self, connection, insert_type):
         col = self.tables.data_table.c["data"]
 
-        with config.db.begin() as conn:
-            conn.execute(
-                self.tables.data_table.insert(), {"name": "r1", "data": None}
+        if insert_type == "parameters":
+            stmt, params = self.tables.data_table.insert(), {
+                "name": "r1",
+                "data": None,
+            }
+        elif insert_type == "multiparameters":
+            stmt, params = self.tables.data_table.insert(), [
+                {"name": "r1", "data": None}
+            ]
+        elif insert_type == "values":
+            stmt, params = (
+                self.tables.data_table.insert().values(name="r1", data=None),
+                {},
             )
+        else:
+            assert False
 
-            eq_(
-                conn.scalar(
-                    select(self.tables.data_table.c.name).where(
-                        cast(col, String) == "null"
-                    )
-                ),
-                "r1",
-            )
+        conn = connection
+        conn.execute(stmt, params)
+
+        eq_(
+            conn.scalar(
+                select(self.tables.data_table.c.name).where(
+                    cast(col, String) == "null"
+                )
+            ),
+            "r1",
+        )
 
-            eq_(conn.scalar(select(col)), None)
+        eq_(conn.scalar(select(col)), None)
 
     def test_unicode_round_trip(self):
         # note we include Unicode supplementary characters as well
index 3cbd2c07f948e63b80f8c1d57093113d848bd241..309baaabff7e1833ca89fe4b4a545168b68f8cef 100644 (file)
@@ -490,6 +490,8 @@ class _UserDefinedTypeFixture(object):
 
             def bind_processor(self, dialect):
                 def process(value):
+                    if value is None:
+                        value = "<null value>"
                     return "BIND_IN" + value
 
                 return process
@@ -513,6 +515,8 @@ class _UserDefinedTypeFixture(object):
                 ) or (lambda value: value)
 
                 def process(value):
+                    if value is None:
+                        value = "<null value>"
                     return "BIND_IN" + impl_processor(value)
 
                 return process
@@ -535,6 +539,8 @@ class _UserDefinedTypeFixture(object):
             cache_ok = True
 
             def process_bind_param(self, value, dialect):
+                if value is None:
+                    value = u"<null value>"
                 return "BIND_IN" + value
 
             def process_result_value(self, value, dialect):
@@ -548,6 +554,8 @@ class _UserDefinedTypeFixture(object):
             cache_ok = True
 
             def process_bind_param(self, value, dialect):
+                if value is None:
+                    value = 29
                 return value * 10
 
             def process_result_value(self, value, dialect):
@@ -573,6 +581,9 @@ class _UserDefinedTypeFixture(object):
                 ) or (lambda value: value)
 
                 def process(value):
+                    if value is None:
+                        value = u"<null value>"
+
                     return "BIND_IN" + impl_processor(value)
 
                 return process
@@ -654,6 +665,19 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
                 goofy10=9,
             ),
         )
+        connection.execute(
+            users.insert(),
+            dict(
+                user_id=5,
+                goofy=None,
+                goofy2=None,
+                goofy4=None,
+                goofy7=None,
+                goofy8=None,
+                goofy9=None,
+                goofy10=None,
+            ),
+        )
 
     def test_processing(self, connection):
         users = self.tables.users
@@ -662,22 +686,51 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         result = connection.execute(
             users.select().order_by(users.c.user_id)
         ).fetchall()
-        for assertstr, assertint, assertint2, row in zip(
+        eq_(
+            result,
             [
-                "BIND_INjackBIND_OUT",
-                "BIND_INlalaBIND_OUT",
-                "BIND_INfredBIND_OUT",
+                (
+                    2,
+                    "BIND_INjackBIND_OUT",
+                    "BIND_INjackBIND_OUT",
+                    "BIND_INjackBIND_OUT",
+                    "BIND_INjackBIND_OUT",
+                    1200,
+                    1800,
+                    1200,
+                ),
+                (
+                    3,
+                    "BIND_INlalaBIND_OUT",
+                    "BIND_INlalaBIND_OUT",
+                    "BIND_INlalaBIND_OUT",
+                    "BIND_INlalaBIND_OUT",
+                    1500,
+                    2250,
+                    1500,
+                ),
+                (
+                    4,
+                    "BIND_INfredBIND_OUT",
+                    "BIND_INfredBIND_OUT",
+                    "BIND_INfredBIND_OUT",
+                    "BIND_INfredBIND_OUT",
+                    900,
+                    1350,
+                    900,
+                ),
+                (
+                    5,
+                    "BIND_IN<null value>BIND_OUT",
+                    "BIND_IN<null value>BIND_OUT",
+                    "BIND_IN<null value>BIND_OUT",
+                    "BIND_IN<null value>BIND_OUT",
+                    2900,
+                    4350,
+                    2900,
+                ),
             ],
-            [1200, 1500, 900],
-            [1800, 2250, 1350],
-            result,
-        ):
-            for col in list(row)[1:5]:
-                eq_(col, assertstr)
-            eq_(row[5], assertint)
-            eq_(row[6], assertint2)
-            for col in row[3], row[4]:
-                assert isinstance(col, util.text_type)
+        )
 
     def test_plain_in_typedec(self, connection):
         users = self.tables.users
@@ -728,6 +781,64 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         eq_(result.fetchall(), [(3, 1500), (4, 900)])
 
 
+class BindProcessorInsertValuesTest(UserDefinedRoundTripTest):
+    """related to #6770, test that insert().values() applies to
+    bound parameter handlers including the None value."""
+
+    __backend__ = True
+
+    def _data_fixture(self, connection):
+        users = self.tables.users
+        connection.execute(
+            users.insert().values(
+                user_id=2,
+                goofy="jack",
+                goofy2="jack",
+                goofy4=util.u("jack"),
+                goofy7=util.u("jack"),
+                goofy8=12,
+                goofy9=12,
+                goofy10=12,
+            ),
+        )
+        connection.execute(
+            users.insert().values(
+                user_id=3,
+                goofy="lala",
+                goofy2="lala",
+                goofy4=util.u("lala"),
+                goofy7=util.u("lala"),
+                goofy8=15,
+                goofy9=15,
+                goofy10=15,
+            ),
+        )
+        connection.execute(
+            users.insert().values(
+                user_id=4,
+                goofy="fred",
+                goofy2="fred",
+                goofy4=util.u("fred"),
+                goofy7=util.u("fred"),
+                goofy8=9,
+                goofy9=9,
+                goofy10=9,
+            ),
+        )
+        connection.execute(
+            users.insert().values(
+                user_id=5,
+                goofy=None,
+                goofy2=None,
+                goofy4=None,
+                goofy7=None,
+                goofy8=None,
+                goofy9=None,
+                goofy10=None,
+            ),
+        )
+
+
 class UserDefinedTest(
     _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL
 ):