]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add columns_clause_froms and related use cases
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Jul 2021 20:07:50 +0000 (16:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 7 Aug 2021 18:47:29 +0000 (14:47 -0400)
Added new attribute :attr:`_sql.Select.columns_clause_froms` that will
retrieve the FROM list implied by the columns clause of the
:class:`_sql.Select` statement. This differs from the old
:attr:`_sql.Select.froms` collection in that it does not perform any ORM
compilation steps, which necessarily deannotate the FROM elements and do
things like compute joinedloads etc., which makes it not an appropriate
candidate for the :meth:`_sql.Select.select_from` method. Additionally adds
a new parameter
:paramref:`_sql.Select.with_only_columns.maintain_column_froms` that
transfers this collection to :meth:`_sql.Select.select_from` before
replacing the columns collection.

In addition, the :attr:`_sql.Select.froms` is renamed to
:meth:`_sql.Select.get_final_froms`, to stress that this collection is not
a simple accessor and is instead calculated given the full state of the
object, which can be an expensive call when used in an ORM context.

Additionally fixes a regression involving the
:func:`_orm.with_only_columns` function to support applying criteria to
column elements that were replaced with either
:meth:`_sql.Select.with_only_columns` or :meth:`_orm.Query.with_entities` ,
which had broken as part of :ticket:`6503` released in 1.4.19.

Fixes: #6808
Change-Id: Ib5d66cce488bbaca06dab4f68fb5cdaa73e8823e

doc/build/changelog/unreleased_14/6808.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_core_compilation.py
test/orm/test_relationship_criteria.py
test/sql/test_deprecations.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/6808.rst b/doc/build/changelog/unreleased_14/6808.rst
new file mode 100644 (file)
index 0000000..803a944
--- /dev/null
@@ -0,0 +1,26 @@
+.. change::
+    :tags: orm, usecase
+    :tickets: 6808
+
+    Added new attribute :attr:`_sql.Select.columns_clause_froms` that will
+    retrieve the FROM list implied by the columns clause of the
+    :class:`_sql.Select` statement. This differs from the old
+    :attr:`_sql.Select.froms` collection in that it does not perform any ORM
+    compilation steps, which necessarily deannotate the FROM elements and do
+    things like compute joinedloads etc., which makes it not an appropriate
+    candidate for the :meth:`_sql.Select.select_from` method. Additionally adds
+    a new parameter
+    :paramref:`_sql.Select.with_only_columns.maintain_column_froms` that
+    transfers this collection to :meth:`_sql.Select.select_from` before
+    replacing the columns collection.
+
+    In addition, the :attr:`_sql.Select.froms` is renamed to
+    :meth:`_sql.Select.get_final_froms`, to stress that this collection is not
+    a simple accessor and is instead calculated given the full state of the
+    object, which can be an expensive call when used in an ORM context.
+
+    Additionally fixes a regression involving the
+    :func:`_orm.with_only_columns` function to support applying criteria to
+    column elements that were replaced with either
+    :meth:`_sql.Select.with_only_columns` or :meth:`_orm.Query.with_entities` ,
+    which had broken as part of :ticket:`6503` released in 1.4.19.
index c4b695687631dffb85ed1f53ecf07e5080cee853..60347781905ca0362a1093cd78ef975860003cf9 100644 (file)
@@ -4,6 +4,8 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+import itertools
+
 from . import attributes
 from . import interfaces
 from . import loading
@@ -872,6 +874,19 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 for elem in _select_iterables([element]):
                     yield elem
 
+    @classmethod
+    def get_columns_clause_froms(cls, statement):
+        return cls._normalize_froms(
+            itertools.chain.from_iterable(
+                element._from_objects
+                if "parententity" not in element._annotations
+                else [
+                    element._annotations["parententity"].__clause_element__()
+                ]
+                for element in statement._raw_columns
+            )
+        )
+
     @classmethod
     @util.preload_module("sqlalchemy.orm.query")
     def from_statement(cls, statement, from_statement):
index 7cfb3589d0e528e61fabb8ed59cada43da305ada..46bb3c94356df279a6a5c121da2349eaebba4cfb 100644 (file)
@@ -1106,6 +1106,11 @@ class LoaderCriteriaOption(CriteriaOption):
         else:
             return self.where_criteria
 
+    def process_compile_state_replaced_entities(
+        self, compile_state, mapper_entities
+    ):
+        return self.process_compile_state(compile_state)
+
     def process_compile_state(self, compile_state):
         """Apply a modification to a given :class:`.CompileState`."""
 
index b6cf7f55e85336145328f9d0675acd7a122f5864..0040db6da75704073d7938f50f089cbeba7ce0e8 100644 (file)
@@ -4223,6 +4223,14 @@ class SelectState(util.MemoizedSlots, CompileState):
     def from_statement(cls, statement, from_statement):
         cls._plugin_not_implemented()
 
+    @classmethod
+    def get_columns_clause_froms(cls, statement):
+        return cls._normalize_froms(
+            itertools.chain.from_iterable(
+                element._from_objects for element in statement._raw_columns
+            )
+        )
+
     @classmethod
     def _column_naming_convention(cls, label_style):
 
@@ -4284,7 +4292,8 @@ class SelectState(util.MemoizedSlots, CompileState):
             check_statement=statement,
         )
 
-    def _normalize_froms(self, iterable_of_froms, check_statement=None):
+    @classmethod
+    def _normalize_froms(cls, iterable_of_froms, check_statement=None):
         """given an iterable of things to select FROM, reduce them to what
         would actually render in the FROM clause of a SELECT.
 
@@ -5347,13 +5356,79 @@ class Select(
         """
         return self.join(target, onclause=onclause, isouter=True, full=full)
 
+    def get_final_froms(self):
+        """Compute the final displayed list of :class:`_expression.FromClause`
+        elements.
+
+        This method will run through the full computation required to
+        determine what FROM elements will be displayed in the resulting
+        SELECT statement, including shadowing individual tables with
+        JOIN objects, as well as full computation for ORM use cases including
+        eager loading clauses.
+
+        For ORM use, this accessor returns the **post compilation**
+        list of FROM objects; this collection will include elements such as
+        eagerly loaded tables and joins.  The objects will **not** be
+        ORM enabled and not work as a replacement for the
+        :meth:`_sql.Select.select_froms` collection; additionally, the
+        method is not well performing for an ORM enabled statement as it
+        will incur the full ORM construction process.
+
+        To retrieve the FROM list that's implied by the "columns" collection
+        passed to the :class:`_sql.Select` originally, use the
+        :attr:`_sql.Select.columns_clause_froms` accessor.
+
+        To select from an alternative set of columns while maintaining the
+        FROM list, use the :meth:`_sql.Select.with_only_columns` method and
+        pass the
+        :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+        parameter.
+
+        .. versionadded:: 1.4.23 - the :meth:`_sql.Select.get_final_froms`
+           method replaces the previous :attr:`_sql.Select.froms` accessor,
+           which is deprecated.
+
+        .. seealso::
+
+            :attr:`_sql.Select.columns_clause_froms`
+
+        """
+        return self._compile_state_factory(self, None)._get_display_froms()
+
     @property
+    @util.deprecated(
+        "1.4.23",
+        "The :attr:`_expression.Select.froms` attribute is moved to "
+        "the :meth:`_expression.Select.get_final_froms` method.",
+    )
     def froms(self):
         """Return the displayed list of :class:`_expression.FromClause`
         elements.
 
+
         """
-        return self._compile_state_factory(self, None)._get_display_froms()
+        return self.get_final_froms()
+
+    @property
+    def columns_clause_froms(self):
+        """Return the set of :class:`_expression.FromClause` objects implied
+        by the columns clause of this SELECT statement.
+
+        .. versionadded:: 1.4.23
+
+        .. seealso::
+
+            :attr:`_sql.Select.froms` - "final" FROM list taking the full
+            statement into account
+
+            :meth:`_sql.Select.with_only_columns` - makes use of this
+            collection to set up a new FROM list
+
+        """
+
+        return SelectState.get_plugin_class(self).get_columns_clause_froms(
+            self
+        )
 
     @property
     def inner_columns(self):
@@ -5525,13 +5600,13 @@ class Select(
         )
 
     @_generative
-    def with_only_columns(self, *columns):
+    def with_only_columns(self, *columns, **kw):
         r"""Return a new :func:`_expression.select` construct with its columns
         clause replaced with the given columns.
 
-        This method is exactly equivalent to as if the original
+        By default, this method is exactly equivalent to as if the original
         :func:`_expression.select` had been called with the given columns
-        clause.   I.e. a statement::
+        clause. E.g. a statement::
 
             s = select(table1.c.a, table1.c.b)
             s = s.with_only_columns(table1.c.b)
@@ -5540,13 +5615,30 @@ class Select(
 
             s = select(table1.c.b)
 
-        Note that this will also dynamically alter the FROM clause of the
-        statement if it is not explicitly stated.  To maintain the FROM
-        clause, ensure the :meth:`_sql.Select.select_from` method is
-        used appropriately::
+        In this mode of operation, :meth:`_sql.Select.with_only_columns`
+        will also dynamically alter the FROM clause of the
+        statement if it is not explicitly stated.
+        To maintain the existing set of FROMs including those implied by the
+        current columns clause, add the
+        :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+        parameter::
+
+            s = select(table1.c.a, table2.c.b)
+            s = s.with_only_columns(table1.c.a, maintain_column_froms=True)
+
+        The above parameter performs a transfer of the effective FROMs
+        in the columns collection to the :meth:`_sql.Select.select_from`
+        method, as though the following were invoked::
 
             s = select(table1.c.a, table2.c.b)
-            s = s.select_from(table2.c.b).with_only_columns(table1.c.a)
+            s = s.select_from(table1, table2).with_only_columns(table1.c.a)
+
+        The :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+        parameter makes use of the :attr:`_sql.Select.columns_clause_froms`
+        collection and performs an operation equivalent to the following::
+
+            s = select(table1.c.a, table2.c.b)
+            s = s.select_from(*s.columns_clause_froms).with_only_columns(table1.c.a)
 
         :param \*columns: column expressions to be used.
 
@@ -5554,13 +5646,27 @@ class Select(
             method accepts the list of column expressions positionally;
             passing the expressions as a list is deprecated.
 
-        """
+        :param maintain_column_froms: boolean parameter that will ensure the
+         FROM list implied from the current columns clause will be transferred
+         to the :meth:`_sql.Select.select_from` method first.
+
+         .. versionadded:: 1.4.23
+
+        """  # noqa E501
 
         # memoizations should be cleared here as of
         # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this
         # is the case for now.
         self._assert_no_memoizations()
 
+        maintain_column_froms = kw.pop("maintain_column_froms", False)
+        if kw:
+            raise TypeError("unknown parameters: %s" % (", ".join(kw),))
+
+        if maintain_column_froms:
+            self.select_from.non_generative(self, *self.columns_clause_froms)
+
+        # then memoize the FROMs etc.
         _MemoizedSelectEntities._generate_for_statement(self)
 
         self._raw_columns = [
index e730d9097581e2b038f5db61318e2da023c46257..2adc438422253a2e26cb4d366dd75e05ef02cb05 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import insert
+from sqlalchemy import inspect
 from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import or_
@@ -21,6 +22,7 @@ from sqlalchemy.orm import query_expression
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import with_expression
+from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.sql import and_
 from sqlalchemy.sql import sqltypes
@@ -30,6 +32,8 @@ from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.util import resolve_lambda
 from .inheritance import _poly_fixtures
@@ -81,7 +85,7 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
 
         stmt = select(User).filter_by(name="ed")
 
-        eq_(stmt.froms, [self.tables.users])
+        eq_(stmt.get_final_froms(), [self.tables.users])
 
     def test_froms_join(self):
         User, Address = self.classes("User", "Address")
@@ -89,7 +93,7 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
 
         stmt = select(User).join(User.addresses)
 
-        assert stmt.froms[0].compare(users.join(addresses))
+        assert stmt.get_final_froms()[0].compare(users.join(addresses))
 
     @testing.combinations(
         (
@@ -166,6 +170,115 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
         eq_(stmt.column_descriptions, expected)
 
 
+class ColumnsClauseFromsTest(QueryTest, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def test_exclude_eagerloads(self):
+        User, Address = self.classes("User", "Address")
+
+        stmt = select(User).options(joinedload(User.addresses))
+
+        froms = stmt.columns_clause_froms
+
+        mapper = inspect(User)
+        is_(froms[0], inspect(User).__clause_element__())
+        eq_(
+            froms[0]._annotations,
+            {
+                "entity_namespace": mapper,
+                "parententity": mapper,
+                "parentmapper": mapper,
+            },
+        )
+        eq_(len(froms), 1)
+
+    def test_maintain_annotations_from_table(self):
+        User, Address = self.classes("User", "Address")
+
+        stmt = select(User)
+
+        mapper = inspect(User)
+        froms = stmt.columns_clause_froms
+        is_(froms[0], inspect(User).__clause_element__())
+        eq_(
+            froms[0]._annotations,
+            {
+                "entity_namespace": mapper,
+                "parententity": mapper,
+                "parentmapper": mapper,
+            },
+        )
+        eq_(len(froms), 1)
+
+    def test_maintain_annotations_from_annoated_cols(self):
+        User, Address = self.classes("User", "Address")
+
+        stmt = select(User.id)
+
+        mapper = inspect(User)
+        froms = stmt.columns_clause_froms
+        is_(froms[0], inspect(User).__clause_element__())
+        eq_(
+            froms[0]._annotations,
+            {
+                "entity_namespace": mapper,
+                "parententity": mapper,
+                "parentmapper": mapper,
+            },
+        )
+        eq_(len(froms), 1)
+
+    def test_with_only_columns_unknown_kw(self):
+        User, Address = self.classes("User", "Address")
+
+        stmt = select(User.id)
+
+        with expect_raises_message(TypeError, "unknown parameters: foo"):
+            stmt.with_only_columns(User.id, foo="bar")
+
+    @testing.combinations((True,), (False,))
+    def test_replace_into_select_from_maintains_existing(self, use_flag):
+        User, Address = self.classes("User", "Address")
+
+        stmt = select(User.id).select_from(Address)
+
+        if use_flag:
+            stmt = stmt.with_only_columns(
+                func.count(), maintain_column_froms=True
+            )
+        else:
+            stmt = stmt.select_from(
+                *stmt.columns_clause_froms
+            ).with_only_columns(func.count())
+
+        # Address is maintained in the FROM list
+        self.assert_compile(
+            stmt, "SELECT count(*) AS count_1 FROM addresses, users"
+        )
+
+    @testing.combinations((True,), (False,))
+    def test_replace_into_select_from_with_loader_criteria(self, use_flag):
+        User, Address = self.classes("User", "Address")
+
+        stmt = select(User.id).options(
+            with_loader_criteria(User, User.name == "ed")
+        )
+
+        if use_flag:
+            stmt = stmt.with_only_columns(
+                func.count(), maintain_column_froms=True
+            )
+        else:
+            stmt = stmt.select_from(
+                *stmt.columns_clause_froms
+            ).with_only_columns(func.count())
+
+        self.assert_compile(
+            stmt,
+            "SELECT count(*) AS count_1 FROM users WHERE users.name = :name_1",
+        )
+
+
 class JoinTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"
 
index 683267b1c87c161e912cc0e9f667cd45d8478a9e..f9b4335df1a68d9fddc446fe6748bd32b93142fa 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import event
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import orm
 from sqlalchemy import select
@@ -25,6 +26,7 @@ from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm.decl_api import declared_attr
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
 
 
@@ -172,6 +174,39 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
             "FROM users WHERE users.name != :name_1",
         )
 
+    def test_criteria_post_replace(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = (
+            select(User)
+            .select_from(User)
+            .options(with_loader_criteria(User, User.name != "name"))
+            .with_only_columns(func.count())
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT count(*) AS count_1 FROM users "
+            "WHERE users.name != :name_1",
+        )
+
+    def test_criteria_post_replace_legacy(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = fixture_session()
+        stmt = (
+            s.query(User)
+            .select_from(User)
+            .options(with_loader_criteria(User, User.name != "name"))
+            .with_entities(func.count())
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT count(*) AS count_1 FROM users "
+            "WHERE users.name != :name_1",
+        )
+
     def test_select_from_mapper_mapper_criteria(self, user_address_fixture):
         User, Address = user_address_fixture
 
index 44135e373e173ac302d87b22806b7faa15fbbd97..9b74ab1fa6572de5a6027daa7ffa63a74d927586 100644 (file)
@@ -451,6 +451,17 @@ class SelectableTest(fixtures.TestBase, AssertsCompiledSQL):
             "deprecated"
         )
 
+    def test_froms_renamed(self):
+        t1 = table("t1", column("q"))
+
+        stmt = select(t1)
+
+        with testing.expect_deprecated(
+            r"The Select.froms attribute is moved to the "
+            r"Select.get_final_froms\(\) method."
+        ):
+            eq_(stmt.froms, [t1])
+
     def test_select_list_argument(self):
 
         with testing.expect_deprecated_20(
index 5a94d4038aee3a48d384bad176ebb7d14f1c610b..b76873490462be9928b49466dc22cf49ba59677d 100644 (file)
@@ -1,5 +1,4 @@
 """Test various algorithmic properties of selectables."""
-
 from sqlalchemy import and_
 from sqlalchemy import bindparam
 from sqlalchemy import Boolean
@@ -590,6 +589,32 @@ class SelectableTest(
             "table1.col3, table1.colx FROM table1) AS anon_1",
         )
 
+    @testing.combinations(
+        (
+            [table1.c.col1],
+            [table1.join(table2)],
+            [table1.join(table2)],
+            [table1],
+        ),
+        ([table1], [table2], [table1, table2], [table1]),
+        (
+            [table1.c.col1, table2.c.col1],
+            [],
+            [table1, table2],
+            [table1, table2],
+        ),
+    )
+    def test_froms_accessors(
+        self, cols_expr, select_from, exp_final_froms, exp_cc_froms
+    ):
+        """tests for #6808"""
+        s1 = select(*cols_expr).select_from(*select_from)
+
+        for ff, efp in util.zip_longest(s1.get_final_froms(), exp_final_froms):
+            assert ff.compare(efp)
+
+        eq_(s1.columns_clause_froms, exp_cc_froms)
+
     def test_scalar_subquery_from_subq_same_source(self):
         s1 = select(table1.c.col1)