]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support "SELECT *" for ORM queries
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Jul 2022 01:24:17 +0000 (21:24 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Jul 2022 01:28:39 +0000 (21:28 -0400)
A :func:`_sql.select` construct that is passed a sole '*' argument for
``SELECT *``, either via string, :func:`_sql.text`, or
:func:`_sql.literal_column`, will be interpreted as a Core-level SQL
statement rather than as an ORM level statement. This is so that the ``*``,
when expanded to match any number of columns, will result in all columns
returned in the result. the ORM- level interpretation of
:func:`_sql.select` needs to know the names and types of all ORM columns up
front which can't be achieved when ``'*'`` is used.

If ``'*`` is used amongst other expressions simultaneously with an ORM
statement, an error is raised as this can't be interpreted correctly by the
ORM.

Fixes: #8235
Change-Id: Ic8e84491e14acdc8570704eadeaeaf6e16b1e870
(cherry picked from commit 3916bfc9ccf2904f69498075849a82ceee225b3a)

doc/build/changelog/unreleased_14/8235.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/elements.py
test/orm/test_loading.py

diff --git a/doc/build/changelog/unreleased_14/8235.rst b/doc/build/changelog/unreleased_14/8235.rst
new file mode 100644 (file)
index 0000000..ea5726e
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8235
+
+    A :func:`_sql.select` construct that is passed a sole '*' argument for
+    ``SELECT *``, either via string, :func:`_sql.text`, or
+    :func:`_sql.literal_column`, will be interpreted as a Core-level SQL
+    statement rather than as an ORM level statement. This is so that the ``*``,
+    when expanded to match any number of columns, will result in all columns
+    returned in the result. the ORM- level interpretation of
+    :func:`_sql.select` needs to know the names and types of all ORM columns up
+    front which can't be achieved when ``'*'`` is used.
+
+    If ``'*`` is used amongst other expressions simultaneously with an ORM
+    statement, an error is raised as this can't be interpreted correctly by the
+    ORM.
index 7cedc2b43cb387df212bcb5d51db271ef4837a4b..9d4f652ea4f3ccce9e884910c14609c283c204e6 100644 (file)
@@ -178,6 +178,7 @@ class ORMCompileState(CompileState):
             ("_set_base_alias", InternalTraversal.dp_boolean),
             ("_for_refresh_state", InternalTraversal.dp_boolean),
             ("_render_for_subquery", InternalTraversal.dp_boolean),
+            ("_is_star", InternalTraversal.dp_boolean),
         ]
 
         # set to True by default from Query._statement_20(), to indicate
@@ -202,6 +203,7 @@ class ORMCompileState(CompileState):
         _set_base_alias = False
         _for_refresh_state = False
         _render_for_subquery = False
+        _is_star = False
 
     current_path = _path_registry
 
@@ -336,6 +338,8 @@ class ORMCompileState(CompileState):
         load_options = execution_options.get(
             "_sa_orm_load_options", QueryContext.default_load_options
         )
+        if compile_state.compile_options._is_star:
+            return result
 
         querycontext = QueryContext(
             compile_state,
@@ -860,6 +864,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         self._for_update_arg = query._for_update_arg
 
+        if self.compile_options._is_star and (len(self._entities) != 1):
+            raise sa_exc.CompileError(
+                "Can't generate ORM query that includes multiple expressions "
+                "at the same time as '*'; query for '*' alone if present"
+            )
         for entity in self._entities:
             entity.setup_compile_state(self)
 
@@ -2941,6 +2950,9 @@ class _RawColumnEntity(_ColumnEntity):
         self.raw_column_index = raw_column_index
         self.translate_raw_column = raw_column_index is not None
 
+        if column._is_star:
+            compile_state.compile_options += {"_is_star": True}
+
         if not is_current_entities or column._is_text_clause:
             self._label_name = None
         else:
index a1891f19cabe9a096cc63fe463bd6c94b73fdb12..c9cea23dadd23561eb649b83483a26b1e7668c4c 100644 (file)
@@ -216,6 +216,7 @@ class ClauseElement(
     _is_lambda_element = False
     _is_singleton_constant = False
     _is_immutable = False
+    _is_star = False
 
     _order_by_label_element = None
 
@@ -1803,6 +1804,10 @@ class TextClause(
 
     _allow_label_resolve = False
 
+    @property
+    def _is_star(self):
+        return self.text == "*"
+
     def __init__(self, text, bind=None):
         self._bind = bind
         self._bindparams = {}
@@ -4795,6 +4800,10 @@ class ColumnClause(
 
     _is_multiparam_column = False
 
+    @property
+    def _is_star(self):
+        return self.is_literal and self.name == "*"
+
     def __init__(self, text, type_=None, is_literal=False, _selectable=None):
         """Produce a :class:`.ColumnClause` object.
 
index 88a160b5a8395b7bfd337b731aa82b0f48c337fa..cc3c3f49424f5e9c9356bac58448c38329b15707 100644 (file)
@@ -1,12 +1,16 @@
 from sqlalchemy import exc
+from sqlalchemy import literal
+from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import testing
+from sqlalchemy import text
 from sqlalchemy.orm import loading
 from sqlalchemy.orm import relationship
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
 from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.fixtures import fixture_session
 from . import _fixtures
 
@@ -14,6 +18,90 @@ from . import _fixtures
 # class LoadOnIdentTest(_fixtures.FixtureTest):
 
 
+class SelectStarTest(_fixtures.FixtureTest):
+    run_setup_mappers = "once"
+    run_inserts = "once"
+    run_deletes = None
+
+    @classmethod
+    def setup_mappers(cls):
+        cls._setup_stock_mapping()
+
+    @testing.combinations(
+        "plain", "text", "literal_column", argnames="exprtype"
+    )
+    @testing.combinations("core", "orm", argnames="coreorm")
+    def test_single_star(self, exprtype, coreorm):
+        """test for #8235"""
+        User, Address = self.classes("User", "Address")
+
+        if exprtype == "plain":
+            star = "*"
+        elif exprtype == "text":
+            star = text("*")
+        elif exprtype == "literal_column":
+            star = literal_column("*")
+        else:
+            assert False
+
+        stmt = (
+            select(star)
+            .select_from(User)
+            .join(Address)
+            .where(User.id == 7)
+            .order_by(User.id, Address.id)
+        )
+
+        s = fixture_session()
+
+        if coreorm == "core":
+            result = s.connection().execute(stmt)
+        elif coreorm == "orm":
+            result = s.execute(stmt)
+        else:
+            assert False
+
+        eq_(result.all(), [(7, "jack", 1, 7, "jack@bean.com")])
+
+    @testing.combinations(
+        "plain", "text", "literal_column", argnames="exprtype"
+    )
+    @testing.combinations(
+        lambda User, star: (star, User.id),
+        lambda User, star: (star, User),
+        lambda User, star: (User.id, star),
+        lambda User, star: (User, star),
+        lambda User, star: (literal("some text"), star),
+        lambda User, star: (star, star),
+        lambda User, star: (star, text("some text")),
+        argnames="testcase",
+    )
+    def test_no_star_orm_combinations(self, exprtype, testcase):
+        """test for #8235"""
+        User = self.classes.User
+
+        if exprtype == "plain":
+            star = "*"
+        elif exprtype == "text":
+            star = text("*")
+        elif exprtype == "literal_column":
+            star = literal_column("*")
+        else:
+            assert False
+
+        args = testing.resolve_lambda(testcase, User=User, star=star)
+        stmt = select(*args).select_from(User)
+
+        s = fixture_session()
+
+        with expect_raises_message(
+            exc.CompileError,
+            r"Can't generate ORM query that includes multiple expressions "
+            r"at the same time as '\*';",
+        ):
+            s.execute(stmt)
+
+
 class InstanceProcessorTest(_fixtures.FixtureTest):
     def test_state_no_load_path_comparison(self):
         # test issue #5110