]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
disable col deduping inside of Bundle
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 May 2024 15:45:31 +0000 (11:45 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 May 2024 19:14:20 +0000 (15:14 -0400)
Fixed issue where attribute key names in :class:`_orm.Bundle` would not be
correct when using ORM enabled :class:`_sql.select` vs.
:class:`_orm.Query`, when the statement contained duplicate column names.

Fixed issue in typing for :class:`_orm.Bundle` where creating a nested
:class:`_orm.Bundle` structure were not allowed.

Fixes: #11347
Change-Id: I24b37c99f83068c668736caaaa06e69a6801ff50
(cherry picked from commit 7d6d7ef73a680d1502ac675b9ae53a6c335b723e)

doc/build/changelog/unreleased_20/11347.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_bundle.py
test/typing/plain_files/orm/orm_querying.py

diff --git a/doc/build/changelog/unreleased_20/11347.rst b/doc/build/changelog/unreleased_20/11347.rst
new file mode 100644 (file)
index 0000000..a0f9652
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11347
+
+    Fixed issue where attribute key names in :class:`_orm.Bundle` would not be
+    correct when using ORM enabled :class:`_sql.select` vs.
+    :class:`_orm.Query`, when the statement contained duplicate column names.
+
+.. change::
+    :tags: bug, typing
+
+    Fixed issue in typing for :class:`_orm.Bundle` where creating a nested
+    :class:`_orm.Bundle` structure were not allowed.
index fcd01e659161ee016f0eec9ba324293d33874678..694e98ae953274bb40e04c110c4442a47d0272bf 100644 (file)
@@ -446,7 +446,7 @@ class ORMCompileState(AbstractORMCompileState):
     ) -> _LabelConventionCallable:
         if legacy:
 
-            def name(col, col_name=None):
+            def name(col, col_name=None, cancel_dedupe=False):
                 if col_name:
                     return col_name
                 else:
@@ -3145,7 +3145,9 @@ class _ORMColumnEntity(_ColumnEntity):
 
         if is_current_entities:
             self._label_name = compile_state._label_convention(
-                column, col_name=orm_key
+                column,
+                col_name=orm_key,
+                cancel_dedupe=parent_bundle is not None,
             )
         else:
             self._label_name = None
index c861bae6e0ffa4563431f06c44819605286f1705..f3bbf4cf29db9bbbb229bb20a7efff805cd269c5 100644 (file)
@@ -185,6 +185,7 @@ _ColumnExpressionArgument = Union[
     _HasClauseElement[_T],
     "SQLCoreOperations[_T]",
     roles.ExpressionElementRole[_T],
+    roles.TypedColumnsClauseRole[_T],
     Callable[[], "ColumnElement[_T]"],
     "LambdaElement",
 ]
index aa823c16b1ec03f8136262bfb137411d3c8c4dee..f33e0a41fb77bb0c6bd73e60c702397f38e5e002 100644 (file)
@@ -4547,6 +4547,7 @@ class SelectState(util.MemoizedSlots, CompileState):
         cls, label_style: SelectLabelStyle
     ) -> _LabelConventionCallable:
         table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+
         dedupe = label_style is not LABEL_STYLE_NONE
 
         pa = prefix_anon_map()
@@ -4555,13 +4556,14 @@ class SelectState(util.MemoizedSlots, CompileState):
         def go(
             c: Union[ColumnElement[Any], TextClause],
             col_name: Optional[str] = None,
+            cancel_dedupe: bool = False,
         ) -> Optional[str]:
             if is_text_clause(c):
                 return None
             elif TYPE_CHECKING:
                 assert is_column_element(c)
 
-            if not dedupe:
+            if not dedupe or cancel_dedupe:
                 name = c._proxy_key
                 if name is None:
                     name = "_no_label"
index 6d613091def16c4dce0a03a6b0692e93c143c521..81e789d1cfe43279454f86259094bf038b3b2b07 100644 (file)
@@ -159,6 +159,65 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
             select(b1.c.d1, b1.c.d2), "SELECT data.d1, data.d2 FROM data"
         )
 
+    @testing.variation("stmt_type", ["legacy", "newstyle"])
+    def test_dupe_col_name(self, stmt_type):
+        """test #11347"""
+        Data = self.classes.Data
+        sess = fixture_session()
+
+        b1 = Bundle("b1", Data.d1, Data.d3)
+
+        if stmt_type.legacy:
+            row = (
+                sess.query(Data.d1, Data.d2, b1)
+                .filter(Data.d1 == "d0d1")
+                .one()
+            )
+        elif stmt_type.newstyle:
+            row = sess.execute(
+                select(Data.d1, Data.d2, b1).filter(Data.d1 == "d0d1")
+            ).one()
+
+        eq_(row[2]._mapping, {"d1": "d0d1", "d3": "d0d3"})
+
+    @testing.variation("stmt_type", ["legacy", "newstyle"])
+    def test_dupe_col_name_nested(self, stmt_type):
+        """test #11347"""
+        Data = self.classes.Data
+        sess = fixture_session()
+
+        class DictBundle(Bundle):
+            def create_row_processor(self, query, procs, labels):
+                def proc(row):
+                    return dict(zip(labels, (proc(row) for proc in procs)))
+
+                return proc
+
+        b1 = DictBundle("b1", Data.d1, Data.d3)
+        b2 = DictBundle("b2", Data.d2, Data.d3)
+        b3 = DictBundle("b3", Data.d2, Data.d3, b1, b2)
+
+        if stmt_type.legacy:
+            row = (
+                sess.query(Data.d1, Data.d2, b3)
+                .filter(Data.d1 == "d0d1")
+                .one()
+            )
+        elif stmt_type.newstyle:
+            row = sess.execute(
+                select(Data.d1, Data.d2, b3).filter(Data.d1 == "d0d1")
+            ).one()
+
+        eq_(
+            row[2],
+            {
+                "d2": "d0d2",
+                "d3": "d0d3",
+                "b1": {"d1": "d0d1", "d3": "d0d3"},
+                "b2": {"d2": "d0d2", "d3": "d0d3"},
+            },
+        )
+
     def test_result(self):
         Data = self.classes.Data
         sess = fixture_session()
index 83e0fefabbc91b62a99c0269f7ed847a6abe0e36..8f18e2fcc18124d18fda3b47e18f4384ba3c2c43 100644 (file)
@@ -144,3 +144,8 @@ def test_10937() -> None:
     stmt3: ScalarSelect[str] = select(A.data + B.data).scalar_subquery()
 
     select(stmt, stmt2, stmt3, stmt1)
+
+
+def test_bundles() -> None:
+    b1 = orm.Bundle("b1", A.id, A.data)
+    orm.Bundle("b2", A.id, A.data, b1)