From: Mike Bayer Date: Wed, 1 Mar 2023 16:07:25 +0000 (-0500) Subject: TextualSelect is ReturnsRowsRole X-Git-Tag: rel_2_0_5~9^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=45f7b3b8ac9a1b393b45f2f199a88c3bb0c86705;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git TextualSelect is ReturnsRowsRole Fixed typing bug where :meth:`_sql.Select.from_statement` would not accept :func:`_sql.text` or :class:`.TextualSelect` objects as a valid type. Additionally repaired the :class:`.TextClause.columns` method to have a return type, which was missing. Fixes: #9398 Change-Id: I627fc33bf83365e1c7f7c6ed29ea387dfd4a57d8 --- diff --git a/doc/build/changelog/unreleased_20/9398.rst b/doc/build/changelog/unreleased_20/9398.rst new file mode 100644 index 0000000000..731695f244 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9398.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, typing + :tickets: 9398 + + Fixed typing bug where :meth:`_sql.Select.from_statement` would not accept + :func:`_sql.text` or :class:`.TextualSelect` objects as a valid type. + Additionally repaired the :class:`.TextClause.columns` method to have a + return type, which was missing. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index a828d6a0fb..14b1b95940 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -186,6 +186,7 @@ overall which brings in the TextClause object also. """ + _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e51b755ddb..a416b6ac09 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -99,6 +99,7 @@ if typing.TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from .selectable import NamedFromClause + from .selectable import TextualSelect from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -2385,7 +2386,9 @@ class TextClause( return self @util.preload_module("sqlalchemy.sql.selectable") - def columns(self, *cols, **types): + def columns( + self, *cols: _ColumnExpressionArgument[Any], **types: TypeEngine[Any] + ) -> TextualSelect: r"""Turn this :class:`_expression.TextClause` object into a :class:`_expression.TextualSelect` object that serves the same role as a SELECT @@ -2503,29 +2506,38 @@ class TextClause( """ selectable = util.preloaded.sql_selectable + + input_cols: List[NamedColumn[Any]] = [ + coercions.expect(roles.LabeledColumnExprRole, col) for col in cols + ] + positional_input_cols = [ ColumnClause(col.key, types.pop(col.key)) if col.key in types else col - for col in cols + for col in input_cols ] - keyed_input_cols: List[ColumnClause[Any]] = [ + keyed_input_cols: List[NamedColumn[Any]] = [ ColumnClause(key, type_) for key, type_ in types.items() ] - return selectable.TextualSelect( + elem = selectable.TextualSelect.__new__(selectable.TextualSelect) + elem._init( self, positional_input_cols + keyed_input_cols, positional=bool(positional_input_cols) and not keyed_input_cols, ) + return elem @property - def type(self): + def type(self) -> TypeEngine[Any]: return type_api.NULLTYPE @property def comparator(self): - return self.type.comparator_factory(self) + # TODO: this seems wrong, it seems like we might not + # be using this method. + return self.type.comparator_factory(self) # type: ignore def self_group(self, against=None): if against is operators.in_op: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 75b5d09e3f..39ef420dd6 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4551,7 +4551,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select[Any], from_statement: ExecutableReturnsRows + cls, statement: Select[Any], from_statement: roles.ReturnsRowsRole ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @@ -5273,7 +5273,7 @@ class Select( return meth(self) def from_statement( - self, statement: ExecutableReturnsRows + self, statement: roles.ReturnsRowsRole ) -> ExecutableReturnsRows: """Apply the columns which this :class:`.Select` would select onto another statement. @@ -6770,7 +6770,7 @@ class Exists(UnaryExpression[bool]): return e -class TextualSelect(SelectBase, Executable, Generative): +class TextualSelect(SelectBase, ExecutableReturnsRows, Generative): """Wrap a :class:`_expression.TextClause` construct within a :class:`_expression.SelectBase` interface. @@ -6815,14 +6815,28 @@ class TextualSelect(SelectBase, Executable, Generative): def __init__( self, text: TextClause, - columns: List[ColumnClause[Any]], + columns: List[_ColumnExpressionArgument[Any]], + positional: bool = False, + ) -> None: + + self._init( + text, + # convert for ORM attributes->columns, etc + [ + coercions.expect(roles.LabeledColumnExprRole, c) + for c in columns + ], + positional, + ) + + def _init( + self, + text: TextClause, + columns: List[NamedColumn[Any]], positional: bool = False, ) -> None: self.element = text - # convert for ORM attributes->columns, etc - self.column_args = [ - coercions.expect(roles.ColumnsClauseRole, c) for c in columns - ] + self.column_args = columns self.positional = positional @HasMemoized_ro_memoized_attribute diff --git a/test/ext/mypy/plain_files/typed_queries.py b/test/ext/mypy/plain_files/typed_queries.py index 3e67a71325..2de565e6a4 100644 --- a/test/ext/mypy/plain_files/typed_queries.py +++ b/test/ext/mypy/plain_files/typed_queries.py @@ -2,13 +2,19 @@ from __future__ import annotations from typing import Tuple +from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import create_engine from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import insert +from sqlalchemy import Integer +from sqlalchemy import MetaData from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import text from sqlalchemy import update from sqlalchemy.orm import aliased from sqlalchemy.orm import DeclarativeBase @@ -29,6 +35,13 @@ class User(Base): data: Mapped[str] +user_table = Table( + "user", + MetaData(), + Column("id", Integer, primary_key=True), + Column("name", String, primary_key=True), +) + session = Session() e = create_engine("sqlite://") @@ -443,3 +456,29 @@ def t_dml_delete() -> None: # EXPECTED_TYPE: Result[Tuple[int, str]] reveal_type(r1) + + +def t_from_statement() -> None: + + t = text("select * from user") + + # EXPECTED_TYPE: TextClause + reveal_type(t) + + select(User).from_statement(t) + + ts = text("select * from user").columns(User.id, User.name) + + # EXPECTED_TYPE: TextualSelect + reveal_type(ts) + + select(User).from_statement(ts) + + ts2 = text("select * from user").columns( + user_table.c.id, user_table.c.name + ) + + # EXPECTED_TYPE: TextualSelect + reveal_type(ts2) + + select(User).from_statement(ts2)