]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
revise approach for bundle deduping
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 May 2024 14:38:48 +0000 (10:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 May 2024 14:44:01 +0000 (10:44 -0400)
Revise the approach from 7d6d7ef73 to make a special case for
Bundle-targeted columns entirely, and don't involve the
_label_convention() callable.   Add tests for select() with
tablename labeling convention.

Fixes: #11347
Change-Id: I1d15523de5709d45b2b69bc17724831ac3425791
(cherry picked from commit 83f8dd53e362c3ea7562c0076add044740d2c4cc)

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_bundle.py

index 694e98ae953274bb40e04c110c4442a47d0272bf..c2cb54e191c1b4c63a0f6c6efe834a4f7e07d667 100644 (file)
@@ -446,7 +446,7 @@ class ORMCompileState(AbstractORMCompileState):
     ) -> _LabelConventionCallable:
         if legacy:
 
-            def name(col, col_name=None, cancel_dedupe=False):
+            def name(col, col_name=None):
                 if col_name:
                     return col_name
                 else:
@@ -3050,7 +3050,10 @@ class _RawColumnEntity(_ColumnEntity):
         if not is_current_entities or column._is_text_clause:
             self._label_name = None
         else:
-            self._label_name = compile_state._label_convention(column)
+            if parent_bundle:
+                self._label_name = column._proxy_key
+            else:
+                self._label_name = compile_state._label_convention(column)
 
         if parent_bundle:
             parent_bundle._entities.append(self)
@@ -3144,11 +3147,12 @@ class _ORMColumnEntity(_ColumnEntity):
         self.raw_column_index = raw_column_index
 
         if is_current_entities:
-            self._label_name = compile_state._label_convention(
-                column,
-                col_name=orm_key,
-                cancel_dedupe=parent_bundle is not None,
-            )
+            if parent_bundle:
+                self._label_name = orm_key if orm_key else column._proxy_key
+            else:
+                self._label_name = compile_state._label_convention(
+                    column, col_name=orm_key
+                )
         else:
             self._label_name = None
 
index f33e0a41fb77bb0c6bd73e60c702397f38e5e002..be8be8e3add82a3a8b03758ce5f1cdef8db062af 100644 (file)
@@ -4556,14 +4556,13 @@ class SelectState(util.MemoizedSlots, CompileState):
         def go(
             c: Union[ColumnElement[Any], TextClause],
             col_name: Optional[str] = None,
-            cancel_dedupe: bool = False,
         ) -> Optional[str]:
             if is_text_clause(c):
                 return None
             elif TYPE_CHECKING:
                 assert is_column_element(c)
 
-            if not dedupe or cancel_dedupe:
+            if not dedupe:
                 name = c._proxy_key
                 if name is None:
                     name = "_no_label"
index 81e789d1cfe43279454f86259094bf038b3b2b07..a1bd399a4cb6cc39f8398d798b42212c8448b0e7 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import select
+from sqlalchemy import SelectLabelStyle
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import tuple_
@@ -159,29 +160,68 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
             select(b1.c.d1, b1.c.d2), "SELECT data.d1, data.d2 FROM data"
         )
 
-    @testing.variation("stmt_type", ["legacy", "newstyle"])
-    def test_dupe_col_name(self, stmt_type):
+    @testing.variation(
+        "stmt_type", ["legacy", "newstyle", "newstyle_w_label_conv"]
+    )
+    @testing.variation("col_type", ["orm", "core"])
+    def test_dupe_col_name(self, stmt_type, col_type):
         """test #11347"""
         Data = self.classes.Data
         sess = fixture_session()
 
-        b1 = Bundle("b1", Data.d1, Data.d3)
+        if col_type.orm:
+            b1 = Bundle("b1", Data.d1, Data.d3)
+            cols = Data.d1, Data.d2
+        elif col_type.core:
+            data_table = self.tables.data
+            b1 = Bundle("b1", data_table.c.d1, data_table.c.d3)
+            cols = data_table.c.d1, data_table.c.d2
+        else:
+            col_type.fail()
 
         if stmt_type.legacy:
             row = (
-                sess.query(Data.d1, Data.d2, b1)
+                sess.query(cols[0], cols[1], b1)
                 .filter(Data.d1 == "d0d1")
                 .one()
             )
         elif stmt_type.newstyle:
             row = sess.execute(
-                select(Data.d1, Data.d2, b1).filter(Data.d1 == "d0d1")
+                select(cols[0], cols[1], b1).filter(Data.d1 == "d0d1")
             ).one()
+        elif stmt_type.newstyle_w_label_conv:
+            row = sess.execute(
+                select(cols[0], cols[1], b1)
+                .filter(Data.d1 == "d0d1")
+                .set_label_style(
+                    SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL
+                )
+            ).one()
+        else:
+            stmt_type.fail()
+
+        if stmt_type.newstyle_w_label_conv:
+            # decision is made here that even if a SELECT with the
+            # "tablename_plus_colname" label style, within a Bundle we still
+            # use straight column name, even though the overall row
+            # uses tablename_colname
+            eq_(
+                row._mapping,
+                {"data_d1": "d0d1", "data_d2": "d0d2", "b1": ("d0d1", "d0d3")},
+            )
+        else:
+            eq_(
+                row._mapping,
+                {"d1": "d0d1", "d2": "d0d2", "b1": ("d0d1", "d0d3")},
+            )
 
         eq_(row[2]._mapping, {"d1": "d0d1", "d3": "d0d3"})
 
-    @testing.variation("stmt_type", ["legacy", "newstyle"])
-    def test_dupe_col_name_nested(self, stmt_type):
+    @testing.variation(
+        "stmt_type", ["legacy", "newstyle", "newstyle_w_label_conv"]
+    )
+    @testing.variation("col_type", ["orm", "core"])
+    def test_dupe_col_name_nested(self, stmt_type, col_type):
         """test #11347"""
         Data = self.classes.Data
         sess = fixture_session()
@@ -193,9 +233,18 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
 
                 return proc
 
-        b1 = DictBundle("b1", Data.d1, Data.d3)
-        b2 = DictBundle("b2", Data.d2, Data.d3)
-        b3 = DictBundle("b3", Data.d2, Data.d3, b1, b2)
+        if col_type.core:
+            data_table = self.tables.data
+
+            b1 = DictBundle("b1", data_table.c.d1, data_table.c.d3)
+            b2 = DictBundle("b2", data_table.c.d2, data_table.c.d3)
+            b3 = DictBundle("b3", data_table.c.d2, data_table.c.d3, b1, b2)
+        elif col_type.orm:
+            b1 = DictBundle("b1", Data.d1, Data.d3)
+            b2 = DictBundle("b2", Data.d2, Data.d3)
+            b3 = DictBundle("b3", Data.d2, Data.d3, b1, b2)
+        else:
+            col_type.fail()
 
         if stmt_type.legacy:
             row = (
@@ -207,7 +256,45 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
             row = sess.execute(
                 select(Data.d1, Data.d2, b3).filter(Data.d1 == "d0d1")
             ).one()
-
+        elif stmt_type.newstyle_w_label_conv:
+            row = sess.execute(
+                select(Data.d1, Data.d2, b3)
+                .filter(Data.d1 == "d0d1")
+                .set_label_style(
+                    SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL
+                )
+            ).one()
+        else:
+            stmt_type.fail()
+
+        if stmt_type.newstyle_w_label_conv:
+            eq_(
+                row._mapping,
+                {
+                    "data_d1": "d0d1",
+                    "data_d2": "d0d2",
+                    "b3": {
+                        "d2": "d0d2",
+                        "d3": "d0d3",
+                        "b1": {"d1": "d0d1", "d3": "d0d3"},
+                        "b2": {"d2": "d0d2", "d3": "d0d3"},
+                    },
+                },
+            )
+        else:
+            eq_(
+                row._mapping,
+                {
+                    "d1": "d0d1",
+                    "d2": "d0d2",
+                    "b3": {
+                        "d2": "d0d2",
+                        "d3": "d0d3",
+                        "b1": {"d1": "d0d1", "d3": "d0d3"},
+                        "b2": {"d2": "d0d2", "d3": "d0d3"},
+                    },
+                },
+            )
         eq_(
             row[2],
             {