]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
TextualSelect is ReturnsRowsRole
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Mar 2023 16:07:25 +0000 (11:07 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 Mar 2023 15:47:03 +0000 (10:47 -0500)
Fixed typing bug where :meth:`_sql.Select.from_statement` would not accept
:func:`_sql.text` or :class:`.TextualSelect` objects as a valid type.
Additionally repaired the :class:`.TextClause.columns` method to have a
return type, which was missing.

Fixes: #9398
Change-Id: I627fc33bf83365e1c7f7c6ed29ea387dfd4a57d8

doc/build/changelog/unreleased_20/9398.rst [new file with mode: 0644]
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/ext/mypy/plain_files/typed_queries.py

diff --git a/doc/build/changelog/unreleased_20/9398.rst b/doc/build/changelog/unreleased_20/9398.rst
new file mode 100644 (file)
index 0000000..731695f
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9398
+
+    Fixed typing bug where :meth:`_sql.Select.from_statement` would not accept
+    :func:`_sql.text` or :class:`.TextualSelect` objects as a valid type.
+    Additionally repaired the :class:`.TextClause.columns` method to have a
+    return type, which was missing.
index a828d6a0fb8f6986e0547866ce701e7c164c54eb..14b1b95940bda6ffba3ee4efa0cffcfb1544d199 100644 (file)
@@ -186,6 +186,7 @@ overall which brings in the TextClause object also.
 
 """
 
+
 _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]]
 
 _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]]
index e51b755ddbb82bc67f8e6aaaa136d4206edfcd92..a416b6ac09682e41c609b8fea809667772e54d02 100644 (file)
@@ -99,6 +99,7 @@ if typing.TYPE_CHECKING:
     from .selectable import _SelectIterable
     from .selectable import FromClause
     from .selectable import NamedFromClause
+    from .selectable import TextualSelect
     from .sqltypes import TupleType
     from .type_api import TypeEngine
     from .visitors import _CloneCallableType
@@ -2385,7 +2386,9 @@ class TextClause(
         return self
 
     @util.preload_module("sqlalchemy.sql.selectable")
-    def columns(self, *cols, **types):
+    def columns(
+        self, *cols: _ColumnExpressionArgument[Any], **types: TypeEngine[Any]
+    ) -> TextualSelect:
         r"""Turn this :class:`_expression.TextClause` object into a
         :class:`_expression.TextualSelect`
         object that serves the same role as a SELECT
@@ -2503,29 +2506,38 @@ class TextClause(
 
         """
         selectable = util.preloaded.sql_selectable
+
+        input_cols: List[NamedColumn[Any]] = [
+            coercions.expect(roles.LabeledColumnExprRole, col) for col in cols
+        ]
+
         positional_input_cols = [
             ColumnClause(col.key, types.pop(col.key))
             if col.key in types
             else col
-            for col in cols
+            for col in input_cols
         ]
-        keyed_input_cols: List[ColumnClause[Any]] = [
+        keyed_input_cols: List[NamedColumn[Any]] = [
             ColumnClause(key, type_) for key, type_ in types.items()
         ]
 
-        return selectable.TextualSelect(
+        elem = selectable.TextualSelect.__new__(selectable.TextualSelect)
+        elem._init(
             self,
             positional_input_cols + keyed_input_cols,
             positional=bool(positional_input_cols) and not keyed_input_cols,
         )
+        return elem
 
     @property
-    def type(self):
+    def type(self) -> TypeEngine[Any]:
         return type_api.NULLTYPE
 
     @property
     def comparator(self):
-        return self.type.comparator_factory(self)
+        # TODO: this seems wrong, it seems like we might not
+        # be using this method.
+        return self.type.comparator_factory(self)  # type: ignore
 
     def self_group(self, against=None):
         if against is operators.in_op:
index 75b5d09e3f56885574e7b487e0adad7dfb7e1f6c..39ef420dd6ed32747c3d4fa3da3c1c5c779f0b1c 100644 (file)
@@ -4551,7 +4551,7 @@ class SelectState(util.MemoizedSlots, CompileState):
 
     @classmethod
     def from_statement(
-        cls, statement: Select[Any], from_statement: ExecutableReturnsRows
+        cls, statement: Select[Any], from_statement: roles.ReturnsRowsRole
     ) -> ExecutableReturnsRows:
         cls._plugin_not_implemented()
 
@@ -5273,7 +5273,7 @@ class Select(
         return meth(self)
 
     def from_statement(
-        self, statement: ExecutableReturnsRows
+        self, statement: roles.ReturnsRowsRole
     ) -> ExecutableReturnsRows:
         """Apply the columns which this :class:`.Select` would select
         onto another statement.
@@ -6770,7 +6770,7 @@ class Exists(UnaryExpression[bool]):
         return e
 
 
-class TextualSelect(SelectBase, Executable, Generative):
+class TextualSelect(SelectBase, ExecutableReturnsRows, Generative):
     """Wrap a :class:`_expression.TextClause` construct within a
     :class:`_expression.SelectBase`
     interface.
@@ -6815,14 +6815,28 @@ class TextualSelect(SelectBase, Executable, Generative):
     def __init__(
         self,
         text: TextClause,
-        columns: List[ColumnClause[Any]],
+        columns: List[_ColumnExpressionArgument[Any]],
+        positional: bool = False,
+    ) -> None:
+
+        self._init(
+            text,
+            # convert for ORM attributes->columns, etc
+            [
+                coercions.expect(roles.LabeledColumnExprRole, c)
+                for c in columns
+            ],
+            positional,
+        )
+
+    def _init(
+        self,
+        text: TextClause,
+        columns: List[NamedColumn[Any]],
         positional: bool = False,
     ) -> None:
         self.element = text
-        # convert for ORM attributes->columns, etc
-        self.column_args = [
-            coercions.expect(roles.ColumnsClauseRole, c) for c in columns
-        ]
+        self.column_args = columns
         self.positional = positional
 
     @HasMemoized_ro_memoized_attribute
index 3e67a71325b91dae691a0682712d1467a32be979..2de565e6a4af9e798655e76325e63874b264c216 100644 (file)
@@ -2,13 +2,19 @@ from __future__ import annotations
 
 from typing import Tuple
 
+from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import create_engine
 from sqlalchemy import delete
 from sqlalchemy import func
 from sqlalchemy import insert
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
 from sqlalchemy import Select
 from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import text
 from sqlalchemy import update
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import DeclarativeBase
@@ -29,6 +35,13 @@ class User(Base):
     data: Mapped[str]
 
 
+user_table = Table(
+    "user",
+    MetaData(),
+    Column("id", Integer, primary_key=True),
+    Column("name", String, primary_key=True),
+)
+
 session = Session()
 
 e = create_engine("sqlite://")
@@ -443,3 +456,29 @@ def t_dml_delete() -> None:
 
     # EXPECTED_TYPE: Result[Tuple[int, str]]
     reveal_type(r1)
+
+
+def t_from_statement() -> None:
+
+    t = text("select * from user")
+
+    # EXPECTED_TYPE: TextClause
+    reveal_type(t)
+
+    select(User).from_statement(t)
+
+    ts = text("select * from user").columns(User.id, User.name)
+
+    # EXPECTED_TYPE: TextualSelect
+    reveal_type(ts)
+
+    select(User).from_statement(ts)
+
+    ts2 = text("select * from user").columns(
+        user_table.c.id, user_table.c.name
+    )
+
+    # EXPECTED_TYPE: TextualSelect
+    reveal_type(ts2)
+
+    select(User).from_statement(ts2)