]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Further refine labeling for renamed columns
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2021 19:05:49 +0000 (14:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 12 Feb 2021 23:58:53 +0000 (18:58 -0500)
Forked from I22f6cf0f0b3360e55299cdcb2452cead2b2458ea
we are attempting to decide the case for columns mapped
under a different name.   since the .key feature of
Column seems to support this fully, see if an annotation
can be used to indicate an effective .key for a column.

The effective change is that the labeling of column expressions
in rows has been improved to retain the original name of the ORM
attribute even if used in a subquery.

References: #5933
Change-Id: If251f556f7d723f50d349f765f1690d6c679d2ef

15 files changed:
doc/build/changelog/unreleased_14/5933.rst
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_query.py
test/orm/test_utils.py
test/sql/test_compiler.py
test/sql/test_selectable.py

index 2d510413c8a30f20bbd734e668f32deded6f961c..15409cde4f774eaec84b311ca4bcdd7b1f3ecf7e 100644 (file)
@@ -6,7 +6,9 @@
     style would not be preserved in the keys used by result rows; this has been
     applied to all combinations of Core/ORM columns / session vs. connection
     etc. so that the linkage from statement to result row is the same in all
-    cases.
+    cases.   As part of this change, the labeling of column expressions
+    in rows has been improved to retain the original name of the ORM
+    attribute even if used in a subquery.
 
 
 
index dd354c4e0740bebdb845a817321d26efe7b2f3a7..b96b3b61e34885125aa11d8f8ec8f80ff88e09a1 100644 (file)
@@ -212,7 +212,7 @@ class QueryableAttribute(
         """
 
         return self.comparator.__clause_element__()._annotate(
-            {"orm_key": self.key, "entity_namespace": self._entity_namespace}
+            {"proxy_key": self.key, "entity_namespace": self._entity_namespace}
         )
 
     @property
index 621ed826c022da57931b527a0d9f40c863674ff2..fa192a17e5f440f47999ae0ecbb9ebe362af15c9 100644 (file)
@@ -2619,12 +2619,12 @@ class _ORMColumnEntity(_ColumnEntity):
 
         _entity = parententity
 
-        # an AliasedClass won't have orm_key in the annotations for
+        # an AliasedClass won't have proxy_key in the annotations for
         # a column if it was acquired using the class' adapter directly,
         # such as using AliasedInsp._adapt_element().  this occurs
         # within internal loaders.
 
-        orm_key = annotations.get("orm_key", None)
+        orm_key = annotations.get("proxy_key", None)
         if orm_key:
             self.expr = getattr(_entity.entity, orm_key)
         else:
index 695b1a7b4ecfd9370739af4b58192d25dc8632c3..65fe58e615a39705e25e29e31368dc56c1d4174f 100644 (file)
@@ -413,7 +413,7 @@ class CompositeProperty(DescriptorProperty):
                 {
                     "parententity": self._parententity,
                     "parentmapper": self._parententity,
-                    "orm_key": self.prop.key,
+                    "proxy_key": self.prop.key,
                 }
             )
             return CompositeProperty.CompositeBundle(self.prop, clauses)
index 03de64392d82ec9349b8be7281ba015486917091..f19f29daa7e550d930dfb4f9a478be5f3e3e7273 100644 (file)
@@ -2018,7 +2018,7 @@ class BulkUDCompileState(CompileState):
                     elif "entity_namespace" in k._annotations:
                         k_anno = k._annotations
                         attr = _entity_namespace_key(
-                            k_anno["entity_namespace"], k_anno["orm_key"]
+                            k_anno["entity_namespace"], k_anno["proxy_key"]
                         )
                         values.extend(attr._bulk_update_tuples(v))
                     else:
index 4d0c7528bf6855cdfabea68c3e51f392d048f9c2..7823aca20d16f3bab5ac0ed3779c24d07d3a00dc 100644 (file)
@@ -359,7 +359,7 @@ class ColumnProperty(StrategizedProperty):
                 "entity_namespace": pe,
                 "parententity": pe,
                 "parentmapper": pe,
-                "orm_key": self.prop.key,
+                "proxy_key": self.prop.key,
             }
 
             col = column
index 1a01409148c5646ed2f8cfae06dbb7670344f665..b0a2f7b088a47b78c859a8b07834d287cacf7990 100644 (file)
@@ -2693,11 +2693,11 @@ class JoinCondition(object):
         """
 
         self.primaryjoin = _deep_deannotate(
-            self.primaryjoin, values=("parententity", "orm_key")
+            self.primaryjoin, values=("parententity", "proxy_key")
         )
         if self.secondaryjoin is not None:
             self.secondaryjoin = _deep_deannotate(
-                self.secondaryjoin, values=("parententity", "orm_key")
+                self.secondaryjoin, values=("parententity", "proxy_key")
             )
 
     def _determine_joins(self):
index 1bc0ceb4df3cc38cd7607698ba3e285773ee6524..290f502362275e954cbb37839de69f970b3bf5ee 100644 (file)
@@ -771,7 +771,7 @@ class AliasedInsp(
             "compile_state_plugin": "orm",
         }
         if key:
-            d["orm_key"] = key
+            d["proxy_key"] = key
         return (
             self._adapter.traverse(elem)
             ._annotate(d)
index 2bd1c3ae31f9e5d43b58d6397cf54dce8fd83447..85200bf25210dbb4cdeed2fb002968e562fa01e6 100644 (file)
@@ -894,7 +894,9 @@ class ColumnElement(
 
     @util.memoized_property
     def _proxy_key(self):
-        if self.key:
+        if self._annotations and "proxy_key" in self._annotations:
+            return self._annotations["proxy_key"]
+        elif self.key:
             return self.key
         else:
             try:
@@ -987,7 +989,22 @@ class ColumnElement(
         expressions and function calls.
 
         """
-        return self._anon_label(getattr(self, "name", None))
+        name = getattr(self, "name", None)
+        return self._anon_label(name)
+
+    @util.memoized_property
+    def anon_key_label(self):
+        """Provides a constant 'anonymous key label' for this ColumnElement.
+
+        Compare to ``anon_label``, except that the "key" of the column,
+        if available, is used to generate the label.
+
+        This is used when a deduplicating key is placed into the columns
+        collection of a selectable.
+
+        """
+        name = getattr(self, "key", None) or getattr(self, "name", None)
+        return self._anon_label(name)
 
     @util.memoized_property
     def _dedupe_anon_label(self):
@@ -998,6 +1015,10 @@ class ColumnElement(
     def _label_anon_label(self):
         return self._anon_label(getattr(self, "_label", None))
 
+    @util.memoized_property
+    def _label_anon_key_label(self):
+        return self._anon_label(getattr(self, "_key_label", None))
+
     @util.memoized_property
     def _dedupe_label_anon_label(self):
         label = getattr(self, "_label", None) or "anon"
@@ -3720,9 +3741,6 @@ class Grouping(GroupedElement, ColumnElement):
     def _key_label(self):
         return self._label
 
-    def _gen_label(self, name):
-        return name
-
     @property
     def _label(self):
         return getattr(self.element, "_label", None) or self.anon_label
@@ -4345,8 +4363,9 @@ class NamedColumn(ColumnElement):
 
     @HasMemoized.memoized_attribute
     def _key_label(self):
-        if self.key != self.name:
-            return self._gen_label(self.key)
+        proxy_key = self._proxy_key
+        if proxy_key != self.name:
+            return self._gen_label(proxy_key)
         else:
             return self._label
 
@@ -4859,7 +4878,8 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False):
 class AnnotatedColumnElement(Annotated):
     def __init__(self, element, values):
         Annotated.__init__(self, element, values)
-        self.__dict__.pop("comparator", None)
+        for attr in ("comparator", "_proxy_key", "_key_label"):
+            self.__dict__.pop(attr, None)
         for attr in ("name", "key", "table"):
             if self.__dict__.get(attr, False) is None:
                 self.__dict__.pop(attr)
index 127b12e8172ea4228646232d50a594c175ffead2..6e3c9dbfbfda7b4d4c148d494bb2ce25d9f24772 100644 (file)
@@ -1928,6 +1928,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
                 if name_is_truncatable
                 else (name or self.name),
                 self.type,
+                # this may actually be ._proxy_key when the key is incoming
                 key=key if key else name if name else self.key,
                 primary_key=self.primary_key,
                 nullable=self.nullable,
index 7e2c5dd3bd081999bac5f42ff0c4107ac0daabae..23fdf7e12b0bf1936f55dd646739130cb1d9c75f 100644 (file)
@@ -4078,51 +4078,42 @@ class SelectState(util.MemoizedSlots, CompileState):
 
     @classmethod
     def _column_naming_convention(cls, label_style):
-        names = set()
-        pa = []
-
         if label_style is LABEL_STYLE_NONE:
 
             def go(c, col_name=None):
-                return col_name or c._proxy_key
+                return c._proxy_key
 
         elif label_style is LABEL_STYLE_TABLENAME_PLUS_COL:
+            names = set()
+            pa = []  # late-constructed as needed, python 2 has no "nonlocal"
 
             def go(c, col_name=None):
                 # we use key_label since this name is intended for targeting
                 # within the ColumnCollection only, it's not related to SQL
                 # rendering which always uses column name for SQL label names
 
-                if col_name:
-                    name = c._gen_label(col_name)
-                else:
-                    name = c._key_label
+                name = c._key_label
 
                 if name in names:
                     if not pa:
                         pa.append(prefix_anon_map())
 
-                    name = c._label_anon_label % pa[0]
+                    name = c._label_anon_key_label % pa[0]
                 else:
                     names.add(name)
 
                 return name
 
         else:
+            names = set()
+            pa = []  # late-constructed as needed, python 2 has no "nonlocal"
 
             def go(c, col_name=None):
-                # we use key_label since this name is intended for targeting
-                # within the ColumnCollection only, it's not related to SQL
-                # rendering which always uses column name for SQL label names
-                if col_name:
-                    name = col_name
-                else:
-                    name = c._proxy_key
+                name = c._proxy_key
                 if name in names:
                     if not pa:
                         pa.append(prefix_anon_map())
-
-                    name = c.anon_label % pa[0]
+                    name = c.anon_key_label % pa[0]
                 else:
                     names.add(name)
 
@@ -5617,6 +5608,14 @@ class Select(
         return self
 
     def _generate_columns_plus_names(self, anon_for_dupe_key):
+        """Generate column names as rendered in a SELECT statement by
+        the compiler.
+
+        This is distinct from other name generators that are intended for
+        population of .c collections and similar, which may have slightly
+        different rules.
+
+        """
         cols = self._exported_columns_iterator()
 
         # when use_labels is on:
@@ -5732,19 +5731,17 @@ class Select(
                 if key is not None and key in keys_seen:
                     if pa is None:
                         pa = prefix_anon_map()
-                    key = c._label_anon_label % pa
+                    key = c._label_anon_key_label % pa
                 keys_seen.add(key)
             elif disambiguate_only:
-                key = c.key
+                key = c._proxy_key
                 if key is not None and key in keys_seen:
                     if pa is None:
                         pa = prefix_anon_map()
-                    key = c.anon_label % pa
+                    key = c.anon_key_label % pa
                 keys_seen.add(key)
             else:
-                # one of the above label styles is set for subqueries
-                # as of #5221 so this codepath is likely not called now.
-                key = None
+                key = c._proxy_key
             prox.append(
                 c._make_proxy(
                     subquery, key=key, name=name, name_is_truncatable=True
index d86d2ff702a9d99a2b43798fbd2577c156f3737c..f58a1161e47139cc57701fc6bbe321dc8ffa612a 100644 (file)
@@ -25,6 +25,7 @@ from sqlalchemy import LABEL_STYLE_NONE
 from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy import literal
 from sqlalchemy import literal_column
+from sqlalchemy import MetaData
 from sqlalchemy import null
 from sqlalchemy import or_
 from sqlalchemy import select
@@ -71,6 +72,7 @@ from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
 from sqlalchemy.testing.assertions import eq_
 from sqlalchemy.testing.assertions import expect_warnings
+from sqlalchemy.testing.assertions import is_not_none
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -141,19 +143,44 @@ class OnlyReturnTuplesTest(QueryTest):
 class RowTupleTest(QueryTest):
     run_setup_mappers = None
 
-    def test_custom_names(self):
+    @testing.combinations((True,), (False,), argnames="legacy")
+    @testing.combinations((True,), (False,), argnames="use_subquery")
+    @testing.combinations((True,), (False,), argnames="set_column_key")
+    def test_custom_names(self, legacy, use_subquery, set_column_key):
+        """Test labeling as used with ORM attributes named differently from
+        the column.
+
+        Compare to the tests in RowLabelingTest which tests this also,
+        this test is more oriented towards legacy Query use.
+
+        """
         User, users = self.classes.User, self.tables.users
 
-        mapper(User, users, properties={"uname": users.c.name})
+        if set_column_key:
+            uwkey = Table(
+                "users",
+                MetaData(),
+                Column("id", Integer, primary_key=True),
+                Column("name", String, key="uname"),
+            )
+            mapper(User, uwkey)
+        else:
+            mapper(User, users, properties={"uname": users.c.name})
 
-        row = (
-            fixture_session()
-            .query(User.id, User.uname)
-            .filter(User.id == 7)
-            .first()
-        )
+        s = fixture_session()
+        if legacy:
+            q = s.query(User.id, User.uname).filter(User.id == 7)
+            if use_subquery:
+                q = s.query(q.subquery())
+            row = q.first()
+        else:
+            q = select(User.id, User.uname).filter(User.id == 7)
+            if use_subquery:
+                q = select(q.subquery())
+            row = s.execute(q).first()
 
         eq_(row.id, 7)
+
         eq_(row.uname, "jack")
 
     @testing.combinations(
@@ -494,7 +521,7 @@ class RowTupleTest(QueryTest):
 class RowLabelingTest(QueryTest):
     @testing.fixture
     def assert_row_keys(self):
-        def go(stmt, expected, coreorm_exec):
+        def go(stmt, expected, coreorm_exec, selected_columns=None):
 
             if coreorm_exec == "core":
                 with testing.db.connect() as conn:
@@ -506,18 +533,25 @@ class RowLabelingTest(QueryTest):
 
             eq_(row.keys(), expected)
 
+            if selected_columns is None:
+                selected_columns = expected
+
             # we are disambiguating in exported_columns even if
             # LABEL_STYLE_NONE, this seems weird also
             if (
                 stmt._label_style is not LABEL_STYLE_NONE
                 and coreorm_exec == "core"
             ):
-                eq_(stmt.exported_columns.keys(), list(expected))
+                eq_(stmt.exported_columns.keys(), list(selected_columns))
 
             if (
                 stmt._label_style is not LABEL_STYLE_NONE
                 and coreorm_exec == "orm"
             ):
+
+                for k in expected:
+                    is_not_none(getattr(row, k))
+
                 try:
                     column_descriptions = stmt.column_descriptions
                 except (NotImplementedError, AttributeError):
@@ -529,7 +563,7 @@ class RowLabelingTest(QueryTest):
                             for entity in column_descriptions
                             if entity["name"] is not None
                         ],
-                        list(expected),
+                        list(selected_columns),
                     )
 
         return go
@@ -711,6 +745,109 @@ class RowLabelingTest(QueryTest):
         row = s.query(u1, u2).join(u2, u1.c.id == u2.c.id).first()
         eq_(row.keys(), ["id", "name", "id", "name"])
 
+    @testing.fixture
+    def uname_fixture(self):
+        class Foo(object):
+            pass
+
+        if False:
+            m = MetaData()
+            users = Table(
+                "users",
+                m,
+                Column("id", Integer, primary_key=True),
+                Column("name", String, key="uname"),
+            )
+            mapper(Foo, users, properties={"uname": users.c.uname})
+        else:
+            users = self.tables.users
+            mapper(Foo, users, properties={"uname": users.c.name})
+
+        return Foo
+
+    @testing.combinations(
+        (LABEL_STYLE_NONE, ("id", "name"), ("id", "uname")),
+        (LABEL_STYLE_DISAMBIGUATE_ONLY, ("id", "name"), ("id", "uname")),
+        (
+            LABEL_STYLE_TABLENAME_PLUS_COL,
+            ("users_id", "users_name"),
+            ("users_id", "users_uname"),
+        ),
+        argnames="label_style,expected_core,expected_orm",
+    )
+    @testing.combinations(("core",), ("orm",), argnames="coreorm_exec")
+    def test_renamed_properties_columns(
+        self,
+        label_style,
+        expected_core,
+        expected_orm,
+        uname_fixture,
+        assert_row_keys,
+        coreorm_exec,
+    ):
+        Foo = uname_fixture
+
+        stmt = select(Foo.id, Foo.uname).set_label_style(label_style)
+
+        if coreorm_exec == "core":
+            assert_row_keys(
+                stmt,
+                expected_core,
+                coreorm_exec,
+                selected_columns=expected_orm,
+            )
+        else:
+            assert_row_keys(stmt, expected_orm, coreorm_exec)
+
+    @testing.combinations(
+        (
+            LABEL_STYLE_NONE,
+            ("id", "name", "id", "name"),
+            ("id", "uname", "id", "uname"),
+        ),
+        (
+            LABEL_STYLE_DISAMBIGUATE_ONLY,
+            ("id", "name", "id_1", "name_1"),
+            ("id", "uname", "id_1", "uname_1"),
+        ),
+        (
+            LABEL_STYLE_TABLENAME_PLUS_COL,
+            ("u1_id", "u1_name", "u2_id", "u2_name"),
+            ("u1_id", "u1_uname", "u2_id", "u2_uname"),
+        ),
+        argnames="label_style,expected_core,expected_orm",
+    )
+    @testing.combinations(("core",), ("orm",), argnames="coreorm_exec")
+    # @testing.combinations(("orm",), argnames="coreorm_exec")
+    def test_renamed_properties_subq(
+        self,
+        label_style,
+        expected_core,
+        expected_orm,
+        uname_fixture,
+        assert_row_keys,
+        coreorm_exec,
+    ):
+        Foo = uname_fixture
+
+        u1 = select(Foo.id, Foo.uname).subquery("u1")
+        u2 = select(Foo.id, Foo.uname).subquery("u2")
+
+        stmt = (
+            select(u1, u2)
+            .join_from(u1, u2, u1.c.id == u2.c.id)
+            .set_label_style(label_style)
+        )
+        if coreorm_exec == "core":
+            assert_row_keys(
+                stmt,
+                expected_core,
+                coreorm_exec,
+                selected_columns=expected_orm,
+            )
+        else:
+            assert_row_keys(stmt, expected_orm, coreorm_exec)
+
     def test_entity_anon_aliased(self, assert_row_keys):
         User = self.classes.User
 
index 260cae37bd40dd57e2cbbb5f7cefa82d18b0f58c..a2182e3eaa3dd79b60a04dac76d774555f2bcb9e 100644 (file)
@@ -243,7 +243,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL):
                 "entity_namespace": point_mapper,
                 "parententity": point_mapper,
                 "parentmapper": point_mapper,
-                "orm_key": "x_alone",
+                "proxy_key": "x_alone",
             },
         )
         eq_(
@@ -252,7 +252,7 @@ class AliasedClassTest(fixtures.TestBase, AssertsCompiledSQL):
                 "entity_namespace": point_mapper,
                 "parententity": point_mapper,
                 "parentmapper": point_mapper,
-                "orm_key": "x",
+                "proxy_key": "x",
             },
         )
 
index 2a543aa613622399da537afd51820b4a0a77682e..140de8622305dc1d5d28619d564f5c9708469875 100644 (file)
@@ -659,8 +659,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT sum(lala(mytable.myid)) AS bar FROM mytable",
         )
 
+    def test_use_labels_keyed(self):
         self.assert_compile(
-            select(keyed), "SELECT keyed.x, keyed.y" ", keyed.z FROM keyed"
+            select(keyed), "SELECT keyed.x, keyed.y, keyed.z FROM keyed"
         )
 
         self.assert_compile(
index e15c740752da0176656ce3c62ce280fe7b021718..9f0c72247c0b52b893c4efbb9e6962a3c1a3cc0b 100644 (file)
@@ -3067,11 +3067,11 @@ class WithLabelsTest(fixtures.TestBase):
         )
         eq_(
             list(sel.selected_columns.keys()),
-            ["t_x_id", "t_x_b_1"],
+            ["t_x_id", "t_x_id_1"],
         )
         eq_(
             list(sel.subquery().c.keys()),
-            ["t_x_id", "t_x_b_1"],
+            ["t_x_id", "t_x_id_1"],
         )
         self._assert_result_keys(sel, ["t_a", "t_x_b"])
         self._assert_subq_result_keys(sel, ["t_a", "t_x_b"])
@@ -3095,11 +3095,11 @@ class WithLabelsTest(fixtures.TestBase):
         )
         eq_(
             list(sel.selected_columns.keys()),
-            ["t_x_a", "t_x_id_1"],
+            ["t_x_a", "t_x_a_1"],
         )
 
         # deduping for different cols but same label
-        eq_(list(sel.subquery().c.keys()), ["t_x_a", "t_x_id_1"])
+        eq_(list(sel.subquery().c.keys()), ["t_x_a", "t_x_a_1"])
 
         # if we turn off deduping entirely
         # eq_(list(sel.subquery().c.keys()), ["t_x_a", "t_x_a"])
@@ -3115,7 +3115,7 @@ class WithLabelsTest(fixtures.TestBase):
 
     def test_keys_overlap_names_dont_nolabel(self):
         sel = self._keys_overlap_names_dont()
-        eq_(sel.selected_columns.keys(), ["x", "b_1"])
+        eq_(sel.selected_columns.keys(), ["x", "x_1"])
         self._assert_result_keys(sel, ["a", "b"])
 
     def test_keys_overlap_names_dont_label(self):