]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix ORM support for column-named bindparam() in crud .values()
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jan 2023 14:51:23 +0000 (09:51 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Jan 2023 16:48:33 +0000 (11:48 -0500)
Fixed bug / regression where using :func:`.bindparam()` with the same name
as a column in the :meth:`.Update.values` method of :class:`.Update`, as
well as the :meth:`.Insert.values` method of :class:`.Insert` in 2.0 only,
would in some cases silently fail to honor the SQL expression in which the
parameter were presented, replacing the expression with a new parameter of
the same name and discarding any other elements of the SQL expression, such
as SQL functions, etc. The specific case would be statements that were
constructed against ORM entities rather than plain :class:`.Table`
instances, but would occur if the statement were invoked with a
:class:`.Session` or a :class:`.Connection`.

:class:`.Update` part of the issue was present in both 2.0 and 1.4 and is
backported to 1.4.

For 1.4, also backports the sqlalchemy.testing.Variation update
to the variation() API.

Fixes: #9075
Change-Id: Ie954bc1f492ec6a566163588182ef4910c7ee452
(cherry picked from commit b5b864e0fe50243a94c0ef04fddda6fa446c1524)

doc/build/changelog/unreleased_14/9075.rst [new file with mode: 0644]
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/config.py
test/orm/test_core_compilation.py
test/sql/test_compiler.py
test/sql/test_insert.py
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_14/9075.rst b/doc/build/changelog/unreleased_14/9075.rst
new file mode 100644 (file)
index 0000000..0d96be7
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 9075
+    :versions: 2.0.0rc3
+
+    Fixed bug / regression where using :func:`.bindparam()` with the same name
+    as a column in the :meth:`.Update.values` method of :class:`.Update`, as
+    well as the :meth:`.Insert.values` method of :class:`.Insert` in 2.0 only,
+    would in some cases silently fail to honor the SQL expression in which the
+    parameter were presented, replacing the expression with a new parameter of
+    the same name and discarding any other elements of the SQL expression, such
+    as SQL functions, etc. The specific case would be statements that were
+    constructed against ORM entities rather than plain :class:`.Table`
+    instances, but would occur if the statement were invoked with a
+    :class:`.Session` or a :class:`.Connection`.
+
+    :class:`.Update` part of the issue was present in both 2.0 and 1.4 and is
+    backported to 1.4.
index 48ab7212861ced4ee4a1e4372a0ebef0cfb5acea..4f509d9a562a39082537520660d6e56454f71eb6 100644 (file)
@@ -77,14 +77,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
     if compile_state._has_multi_parameters:
         spd = compile_state._multi_parameters[0]
         stmt_parameter_tuples = list(spd.items())
+        spd_str_key = {_column_as_key(key) for key in spd}
     elif compile_state._ordered_values:
         spd = compile_state._dict_parameters
         stmt_parameter_tuples = compile_state._ordered_values
+        spd_str_key = {_column_as_key(key) for key in spd}
     elif compile_state._dict_parameters:
         spd = compile_state._dict_parameters
         stmt_parameter_tuples = list(spd.items())
+        spd_str_key = {_column_as_key(key) for key in spd}
     else:
-        stmt_parameter_tuples = spd = None
+        stmt_parameter_tuples = spd = spd_str_key = None
 
     # if we have statement parameters - set defaults in the
     # compiled params
@@ -94,7 +97,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
         parameters = dict(
             (_column_as_key(key), REQUIRED)
             for key in compiler.column_keys
-            if key not in spd
+            if key not in spd_str_key
         )
     else:
         parameters = dict(
index bfd8cc3dc36bbb9ead781a43e899a7750a81dbb2..28bc3c5efd2c393c0dad87b13db5df45064d46e1 100644 (file)
@@ -50,6 +50,7 @@ from .config import db
 from .config import fixture
 from .config import requirements as requires
 from .config import skip_test
+from .config import Variation
 from .config import variation
 from .exclusions import _is_excluded
 from .exclusions import _server_version
index e61bf2694a11049a241622b6d12796dfb016a738..ab52d233da9427a478129119baf5c6941df532c7 100644 (file)
@@ -94,21 +94,57 @@ def combinations_list(arg_iterable, **kw):
     return combinations(*arg_iterable, **kw)
 
 
-class _variation_base(object):
-    __slots__ = ("name", "argname")
+class Variation(object):
+    __slots__ = ("_name", "_argname")
 
     def __init__(self, case, argname, case_names):
-        self.name = case
-        self.argname = argname
+        self._name = case
+        self._argname = argname
         for casename in case_names:
             setattr(self, casename, casename == case)
 
+    @property
+    def name(self):
+        return self._name
+
     def __bool__(self):
-        return self.name == self.argname
+        return self._name == self._argname
 
     def __nonzero__(self):
         return not self.__bool__()
 
+    def __str__(self):
+        return "%s=%r" % (self._argname, self._name)
+
+    def __repr__(self):
+        return str(self)
+
+    def fail(self):
+        # can't import util.fail() under py2.x without resolving
+        # import cycle
+        assert False, "Unknown %s" % (self,)
+
+    @classmethod
+    def idfn(cls, variation):
+        return variation.name
+
+    @classmethod
+    def generate_cases(cls, argname, cases):
+        case_names = [
+            argname if c is True else "not_" + argname if c is False else c
+            for c in cases
+        ]
+
+        typ = type(
+            argname,
+            (Variation,),
+            {
+                "__slots__": tuple(case_names),
+            },
+        )
+
+        return [typ(casename, argname, case_names) for casename in case_names]
+
 
 def variation(argname, cases):
     """a helper around testing.combinations that provides a single namespace
@@ -138,7 +174,7 @@ def variation(argname, cases):
             elif querytyp.legacy_query:
                 stmt = Session.query(Thing)
             else:
-                assert False
+                querytyp.fail()
 
 
     The variable provided is a slots object of boolean variables, as well
@@ -146,26 +182,35 @@ def variation(argname, cases):
 
     """
 
-    case_names = [
-        argname if c is True else "not_" + argname if c is False else c
-        for c in cases
+    cases_plus_limitations = [
+        entry
+        if (isinstance(entry, tuple) and len(entry) == 2)
+        else (entry, None)
+        for entry in cases
     ]
 
-    typ = type(
-        argname,
-        (_variation_base,),
-        {
-            "__slots__": tuple(case_names),
-        },
+    variations = Variation.generate_cases(
+        argname, [c for c, l in cases_plus_limitations]
     )
-
     return combinations(
-        *[
-            (casename, typ(casename, argname, case_names))
-            for casename in case_names
-        ],
         id_="ia",
-        argnames=argname
+        argnames=argname,
+        *[
+            (variation._name, variation, limitation)
+            if limitation is not None
+            else (variation._name, variation)
+            for variation, (case, limitation) in zip(
+                variations, cases_plus_limitations
+            )
+        ]
+    )
+
+
+def variation_fixture(argname, cases, scope="function"):
+    return fixture(
+        params=Variation.generate_cases(argname, cases),
+        ids=Variation.idfn,
+        scope=scope,
     )
 
 
index 16bdbf2fd4d3681249a7e3056ae74d52e30acc6a..c5a76f04f7b68c8132afba07496814b2dba819a4 100644 (file)
@@ -40,12 +40,14 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import Variation
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.util import resolve_lambda
 from sqlalchemy.util.langhelpers import hybridproperty
 from .inheritance import _poly_fixtures
 from .test_query import QueryTest
+from ..sql import test_compiler
 from ..sql.test_compiler import CorrelateTest as _CoreCorrelateTest
 
 # TODO:
@@ -2643,3 +2645,29 @@ class CorrelateTest(fixtures.DeclarativeMappedTest, _CoreCorrelateTest):
     def _fixture(self):
         t1, t2 = self.classes("T1", "T2")
         return t1, t2, select(t1).where(t1.c.a == t2.c.a)
+
+
+class CrudParamOverlapTest(test_compiler.CrudParamOverlapTest):
+    @testing.fixture(
+        params=Variation.generate_cases("type_", ["orm"]),
+        ids=["orm"],
+    )
+    def crud_table_fixture(self, request):
+        type_ = request.param
+
+        if type_.orm:
+            from sqlalchemy.orm import declarative_base
+
+            Base = declarative_base()
+
+            class Foo(Base):
+                __tablename__ = "mytable"
+                myid = Column(Integer, primary_key=True)
+                name = Column(String)
+                description = Column(String)
+
+            table1 = Foo
+        else:
+            type_.fail()
+
+        yield table1
index 9ede4af9237b44e55028cc2177b53ffae2385129..79826d2fb8d474bd90cc37176d9bfb1e689a3fd1 100644 (file)
@@ -33,6 +33,7 @@ from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Index
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import intersect
 from sqlalchemy import join
@@ -61,6 +62,7 @@ from sqlalchemy import type_coerce
 from sqlalchemy import types
 from sqlalchemy import union
 from sqlalchemy import union_all
+from sqlalchemy import update
 from sqlalchemy import util
 from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import oracle
@@ -97,6 +99,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing import Variation
 from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.types import UserDefinedType
 from sqlalchemy.util import u
@@ -4907,6 +4910,179 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
                 )
 
 
+class CrudParamOverlapTest(AssertsCompiledSQL, fixtures.TestBase):
+    """tests for #9075.
+
+    we apparently allow same-column-named bindparams in values(), even though
+    we do *not* allow same-column-named bindparams in other parts of the
+    statement, but only if the bindparam is associated with that column in the
+    VALUES / SET clause. If you use a name that matches that of a column in
+    values() but associate it with a different column, you also get the error.
+
+    This is supported, see
+    test_insert.py::InsertTest::test_binds_that_match_columns and
+    test_update.py::UpdateTest::test_binds_that_match_columns.  The use
+    case makes sense because the "overlapping binds" issue is that using
+    a column name in bindparam() will conflict with the bindparam()
+    that crud.py is going to make for that column in VALUES / SET; but if we
+    are replacing the actual expression that would be in VALUES / SET, then
+    it's fine, there is no conflict.
+
+    The test suite is extended in
+    test/orm/test_core_compilation.py with ORM mappings that caused
+    the failure that was fixed by #9075.
+
+
+    """
+
+    __dialect__ = "default"
+
+    @testing.fixture(
+        params=Variation.generate_cases("type_", ["lowercase", "uppercase"]),
+        ids=["lowercase", "uppercase"],
+    )
+    def crud_table_fixture(self, request):
+        type_ = request.param
+
+        if type_.lowercase:
+            table1 = table(
+                "mytable",
+                column("myid", Integer),
+                column("name", String),
+                column("description", String),
+            )
+        elif type_.uppercase:
+            table1 = Table(
+                "mytable",
+                MetaData(),
+                Column("myid", Integer),
+                Column("name", String),
+                Column("description", String),
+            )
+        else:
+            type_.fail()
+
+        yield table1
+
+    def test_same_named_binds_insert_values(self, crud_table_fixture):
+        table1 = crud_table_fixture
+        stmt = insert(table1).values(
+            myid=bindparam("myid"),
+            description=func.coalesce(bindparam("description"), "default"),
+        )
+        self.assert_compile(
+            stmt,
+            "INSERT INTO mytable (myid, description) VALUES "
+            "(:myid, coalesce(:description, :coalesce_1))",
+        )
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO mytable (myid, description) VALUES "
+            "(:myid, coalesce(:description, :coalesce_1))",
+            params={"myid": 5, "description": "foo"},
+            checkparams={
+                "coalesce_1": "default",
+                "description": "foo",
+                "myid": 5,
+            },
+        )
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO mytable (myid, name, description) VALUES "
+            "(:myid, :name, coalesce(:description, :coalesce_1))",
+            params={"myid": 5, "description": "foo", "name": "bar"},
+            checkparams={
+                "coalesce_1": "default",
+                "description": "foo",
+                "myid": 5,
+                "name": "bar",
+            },
+        )
+
+    def test_same_named_binds_update_values(self, crud_table_fixture):
+        table1 = crud_table_fixture
+        stmt = update(table1).values(
+            myid=bindparam("myid"),
+            description=func.coalesce(bindparam("description"), "default"),
+        )
+        self.assert_compile(
+            stmt,
+            "UPDATE mytable SET myid=:myid, "
+            "description=coalesce(:description, :coalesce_1)",
+        )
+
+        self.assert_compile(
+            stmt,
+            "UPDATE mytable SET myid=:myid, "
+            "description=coalesce(:description, :coalesce_1)",
+            params={"myid": 5, "description": "foo"},
+            checkparams={
+                "coalesce_1": "default",
+                "description": "foo",
+                "myid": 5,
+            },
+        )
+
+        self.assert_compile(
+            stmt,
+            "UPDATE mytable SET myid=:myid, name=:name, "
+            "description=coalesce(:description, :coalesce_1)",
+            params={"myid": 5, "description": "foo", "name": "bar"},
+            checkparams={
+                "coalesce_1": "default",
+                "description": "foo",
+                "myid": 5,
+                "name": "bar",
+            },
+        )
+
+    def test_different_named_binds_insert_values(self, crud_table_fixture):
+        table1 = crud_table_fixture
+        stmt = insert(table1).values(
+            myid=bindparam("myid"),
+            name=func.coalesce(bindparam("description"), "default"),
+        )
+        self.assert_compile(
+            stmt,
+            "INSERT INTO mytable (myid, name) VALUES "
+            "(:myid, coalesce(:description, :coalesce_1))",
+        )
+
+        with expect_raises_message(
+            exc.CompileError, r"bindparam\(\) name 'description' is reserved "
+        ):
+            stmt.compile(column_keys=["myid", "description"])
+
+        with expect_raises_message(
+            exc.CompileError, r"bindparam\(\) name 'description' is reserved "
+        ):
+            stmt.compile(column_keys=["myid", "description", "name"])
+
+    def test_different_named_binds_update_values(self, crud_table_fixture):
+        table1 = crud_table_fixture
+        stmt = update(table1).values(
+            myid=bindparam("myid"),
+            name=func.coalesce(bindparam("description"), "default"),
+        )
+        self.assert_compile(
+            stmt,
+            "UPDATE mytable SET myid=:myid, "
+            "name=coalesce(:description, :coalesce_1)",
+        )
+
+        with expect_raises_message(
+            exc.CompileError, r"bindparam\(\) name 'description' is reserved "
+        ):
+            stmt.compile(column_keys=["myid", "description"])
+
+        with expect_raises_message(
+            exc.CompileError, r"bindparam\(\) name 'description' is reserved "
+        ):
+            stmt.compile(column_keys=["myid", "description", "name"])
+
+
 class UnsupportedTest(fixtures.TestBase):
     def test_unsupported_element_str_visit_name(self):
         from sqlalchemy.sql.expression import ClauseElement
index 741859fb2cf7abb4c3c16f80607a4d2b69282e65..c052ac5da43ad5c5f348d7721d403b6cfd2c4487 100644 (file)
@@ -68,7 +68,11 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
 
     def test_binds_that_match_columns(self):
         """test bind params named after column names
-        replace the normal SET/VALUES generation."""
+        replace the normal SET/VALUES generation.
+
+        See also test_compiler.py::CrudParamOverlapTest
+
+        """
 
         t = table("foo", column("x"), column("y"))
 
index 93deae5565edb7749fb4e7347ae1b7aed59e067e..214fb913fa6e0c21cfc71667dde2b5127bd9ea2e 100644 (file)
@@ -316,7 +316,11 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
 
     def test_binds_that_match_columns(self):
         """test bind params named after column names
-        replace the normal SET/VALUES generation."""
+        replace the normal SET/VALUES generation.
+
+        See also test_compiler.py::CrudParamOverlapTest
+
+        """
 
         t = table("foo", column("x"), column("y"))