From: Federico Caselli Date: Thu, 22 Jan 2026 23:04:56 +0000 (+0100) Subject: Improve typing story for core from clauses. X-Git-Tag: rel_2_1_0b2~28^2 X-Git-Url: http://git.ipfire.org/gitweb/?a=commitdiff_plain;h=b0b74dc965df2433f6cd951c89499738e97028c0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve typing story for core from clauses. Most :class:`_sql.FromClause` subclasses are not generic on :class:`.TypedColumns` subclasses, that can be used to type their :attr:`_sql.FromClause.c` collection. This applied to :class:`_schema.Table`, :class:`_sql.Join`, :class:`_sql.Subquery`, :class:`_sql.CTE` and more. Fixes: #13085 Change-Id: I724aca887a85c4a401df875903eda12125066680 --- diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index 193ef56690..e48fde7071 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -198,7 +198,7 @@ dataclass-level default (i.e. set using any of the :paramref:`_orm.column_property.default`, or :paramref:`_orm.deferred.default` parameters) is directed to be delivered at the Python :term:`descriptor` level using mechanisms in SQLAlchemy's attribute -system that normally return ``None`` for un-popualted columns, so that even though the default is not +system that normally return ``None`` for un-populated columns, so that even though the default is not populated into ``__dict__``, it's still delivered when the attribute is accessed. This behavior is based on what Python dataclasses itself does when a default is indicated for a field that also includes ``init=False``. @@ -721,6 +721,92 @@ up front, which would be verbose and not automatic. :ticket:`10635` +.. _change_13085: + +Better type checker integration for Core froms, like Table +---------------------------------------------------------- + +SQLAlchemy 2.1 changes :class:`_schema.Table`, along with most +:class:`_sql.FromClause` subclasses, to be generic on the column collection, +providing the option for better static type checking support. +By declaring the columns using a :class:`_schema.TypedColumns` subclass and +providing it to the :class:`_schema.Table` instance, IDEs and type checkers +can infer the exact types of columns when accessing them via the +:attr:`_schema.Table.c` attribute, enabling better autocomplete and type validation. + +Example usage:: + + from sqlalchemy import Table, TypedColumns, Column, Integer, MetaData, select + + + class user_cols(TypedColumns): + id = Column(Integer, primary_key=True) + name: Column[str] + age: Column[int] + + # optional, used to infer the select types when selecting the table + __row_pos__: tuple[int, str, int] + + + metadata = MetaData() + user = Table("user", metadata, user_cols) + + # Type checkers now understand the column types when selecting single columns + stmt = select(user.c.id, user.c.name) # Inferred as Select[int, str] + + # and also when selecting the whole table, when __row_pos__ is present + stmt = select(user) # Inferred as Select[int, str, int] + +The optional :attr:`sqlalchemy.sql._annotated_cols.HasRowPos.__row_pos__` annotation +is used to infer the types of a select when selecting the table directly. + +Columns can be declared in :class:`.TypedColumns` subclasses by instantiating +them directly or by using only a type annotations, that will be inferred when +generating a :class:`_schema.Table`. + +Other :class:`_sql.FromClause`, like :class:`_sql.Join`, :class:`_sql.CTE`, etc, can be made +generic using the :meth:`_sql.FromClause.with_cols` method:: + + # using with_cols the ``c`` collection of the cte has typed tables + cte = user.select().cte().with_cols(user_cols) + +ORM Integration +^^^^^^^^^^^^^^^ + +This functionality also offers some integration with the ORM, by using +:class:`_orm.MappedColumn` annotated attributes in the ORM model and +:func:`_orm.as_typed_table` to get an annotated :class:`_sql.FromClause`:: + + from sqlalchemy import TypedColumns + from sqlalchemy.orm import DeclarativeBase, mapped_column + from sqlalchemy.orm import MappedColumn, as_typed_table + + + class Base(DeclarativeBase): + pass + + + class A(Base): + __tablename__ = "a" + __typed_cols__: "a_cols" + + id: MappedColumn[int] = mapped_column(primary_key=True) + data: MappedColumn[str] + + + class a_cols(A, TypedColumns): + pass + + + # table_a is annotated as FromClause[a_cols], and is just A.__table__ + table_a = as_typed_table(A) + +For proper typing integration :class:`_orm.MappedColumn` should be used +to annotate the single columns, since it's a more specific annotation than +the usual :class:`_orm.Mapped` used for ORM attributes. + +:ticket:`13085` + .. _change_8601: ``filter_by()`` now searches across all FROM clause entities diff --git a/doc/build/changelog/unreleased_21/13085.rst b/doc/build/changelog/unreleased_21/13085.rst new file mode 100644 index 0000000000..6ab65ebd16 --- /dev/null +++ b/doc/build/changelog/unreleased_21/13085.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: schema, usecase + :tickets: 13085 + + Most :class:`_sql.FromClause` subclasses are now generic on + :class:`_schema.TypedColumns` subclasses, that can be used to type their + :attr:`_sql.FromClause.c` collection. + This applied to :class:`_schema.Table`, :class:`_sql.Join`, + :class:`_sql.Subquery`, :class:`_sql.CTE` and more. + + .. seealso:: + + :ref:`change_13085` \ No newline at end of file diff --git a/doc/build/core/metadata.rst b/doc/build/core/metadata.rst index 93af90e42e..9dd5e99af6 100644 --- a/doc/build/core/metadata.rst +++ b/doc/build/core/metadata.rst @@ -916,3 +916,13 @@ Column, Table, MetaData API .. autoclass:: Table :members: :inherited-members: + +.. autoclass:: TypedColumns + :members: + +.. autoclass:: Named + :members: + +.. autoclass:: sqlalchemy.sql._annotated_cols.HasRowPos + :special-members: __row_pos__ + diff --git a/doc/build/orm/mapping_api.rst b/doc/build/orm/mapping_api.rst index ca01b2e22b..286839b65c 100644 --- a/doc/build/orm/mapping_api.rst +++ b/doc/build/orm/mapping_api.rst @@ -150,4 +150,6 @@ Class Mapping API .. autofunction:: unmapped_dataclass +.. autofunction:: as_typed_table + diff --git a/doc/build/tutorial/metadata.rst b/doc/build/tutorial/metadata.rst index 88ad92489f..5b3730851b 100644 --- a/doc/build/tutorial/metadata.rst +++ b/doc/build/tutorial/metadata.rst @@ -197,7 +197,39 @@ parameter. related column, in the above example the :class:`_types.Integer` datatype of the ``user_account.id`` column. -In the next section we will emit the completed DDL for the ``user`` and +Using :class:`.TypedColumns` to get a better typing experience +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A SQLAlchemy :class:`_schema.Table` can also be defined using a +:class:`_schema.TypedColumns` to offers better integration with type checker and IDEs. +The tables defined above could be declared as follows:: + + >>> from sqlalchemy import Named, TypedColumns, Table + >>> other_meta = MetaData() + >>> class user_cols(TypedColumns): + ... id: Named[int] = Column(primary_key=True) + ... name: Named[str | None] = Column(String(30)) + ... fullname: Named[str | None] + + >>> typed_user_table = Table("user_account", other_meta, user_cols) + + >>> class address_cols(TypedColumns): + ... id: Named[int] = Column(primary_key=True) + ... user_id: Named[int] = Column(ForeignKey("user_account.id")) + ... email_address: Named[str] + ... __row_pos__: tuple[int, int, str] + + >>> typed_address_table = Table("address", other_meta, address_cols) + +The columns are defined by subclassing :class:`.TypedColumns`, so that +static type checkers can understand what columns are present in the +:attr:`_schema.Table.c` collection. Functionally the two methods of defining +the metadata objects are equivalent. +The optional ``__row_pos__`` annotation is an aid to type checker so that +they can correctly suggest the type to apply when selecting from the complete +table, without specifying the single columns. + +In the next section we will emit the completed DDL for the ``user_account`` and ``address`` table to see the completed result. .. _tutorial_emitting_ddl: @@ -576,7 +608,7 @@ are found to be present already: .. _tutorial_table_reflection: Table Reflection -------------------------------- +---------------- .. topic:: Optional Section diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index b594e3c665..4eb12b5b27 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -78,9 +78,11 @@ from .schema import Identity as Identity from .schema import Index as Index from .schema import insert_sentinel as insert_sentinel from .schema import MetaData as MetaData +from .schema import Named as Named from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .schema import Sequence as Sequence from .schema import Table as Table +from .schema import TypedColumns as TypedColumns from .schema import UniqueConstraint as UniqueConstraint from .sql import ColumnExpressionArgument as ColumnExpressionArgument from .sql import NotNullable as NotNullable diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index e24e311660..43e1d980a1 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -56,6 +56,7 @@ from .context import FromStatement as FromStatement from .context import QueryContext as QueryContext from .decl_api import add_mapped_attribute as add_mapped_attribute from .decl_api import as_declarative as as_declarative +from .decl_api import as_typed_table as as_typed_table from .decl_api import declarative_base as declarative_base from .decl_api import declarative_mixin as declarative_mixin from .decl_api import DeclarativeBase as DeclarativeBase diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 2f87d585f4..5f4268f705 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -27,11 +27,14 @@ from typing import TypeVar from typing import Union from . import exc +from ._typing import _O from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. import util from ..sql import roles +from ..sql._typing import _T +from ..sql._typing import _T_co from ..sql.elements import SQLColumnExpression from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag @@ -46,18 +49,16 @@ if typing.TYPE_CHECKING: from .instrumentation import ClassManager from .interfaces import PropComparator from .mapper import Mapper + from .properties import MappedColumn from .state import InstanceState from .util import AliasedClass from .writeonly import WriteOnlyCollection + from ..sql._annotated_cols import TypedColumns from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _InfoType from ..sql.elements import ColumnElement from ..sql.operators import OperatorType - -_T = TypeVar("_T", bound=Any) -_T_co = TypeVar("_T_co", bound=Any, covariant=True) - -_O = TypeVar("_O", bound=object) + from ..sql.schema import Column class LoaderCallableStatus(Enum): @@ -804,6 +805,11 @@ class Mapped( if typing.TYPE_CHECKING: + @overload + def __get__( # type: ignore[misc] + self: MappedColumn[_T_co], instance: TypedColumns, owner: Any + ) -> Column[_T_co]: ... + @overload def __get__( self, instance: None, owner: Any @@ -814,7 +820,7 @@ class Mapped( def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... + ) -> Union[InstrumentedAttribute[_T_co], Column[_T_co], _T_co]: ... @classmethod def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 2363595134..e42cbf3394 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -23,6 +23,7 @@ from typing import Literal from typing import Mapping from typing import Optional from typing import overload +from typing import Protocol from typing import Set from typing import Tuple from typing import Type @@ -52,6 +53,7 @@ from .decl_base import _DeclarativeMapperConfig from .decl_base import _DeferredDeclarativeConfig from .decl_base import _del_attribute from .decl_base import _ORMClassConfigurator +from .decl_base import MappedClassProtocol from .descriptor_props import Composite from .descriptor_props import Synonym from .descriptor_props import Synonym as _orm_synonym @@ -65,6 +67,7 @@ from .. import util from ..event import dispatcher from ..event import EventTarget from ..sql import sqltypes +from ..sql._annotated_cols import _TC from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData @@ -753,6 +756,100 @@ class _DeclarativeTyping(TypingOnly): def __init__(self, **kw: Any): ... +class MappedClassWithTypedColumnsProtocol(Protocol[_TC]): + """An ORM mapped class that also defines in the ``__typed_cols__`` + attribute its . + """ + + __typed_cols__: _TC + """The :class:`_schema.TypedColumns` of this ORM mapped class.""" + + __name__: ClassVar[str] + __mapper__: ClassVar[Mapper[Any]] + __table__: ClassVar[FromClause] + + +@overload +def as_typed_table( + cls: type[MappedClassWithTypedColumnsProtocol[_TC]], / +) -> FromClause[_TC]: ... + + +@overload +def as_typed_table( + cls: MappedClassProtocol[Any], typed_columns_cls: type[_TC], / +) -> FromClause[_TC]: ... + + +def as_typed_table( + cls: ( + MappedClassProtocol[Any] + | type[MappedClassWithTypedColumnsProtocol[Any]] + ), + typed_columns_cls: Any = None, + /, +) -> FromClause[Any]: + """Return a typed :class:`_sql.FromClause` from the give ORM model. + + This function is just a typing help, at runtime it just returns the + ``__table__`` attribute of the provided ORM model. + + It's usually called providing both the ORM model and the + :class:`_schema.TypedColumns` class. Single argument calls are supported + if the ORM model class provides an annotation pointing to its + :class:`_schema.TypedColumns` in the ``__typed_cols__`` attribute. + + + Example usage:: + + from sqlalchemy import TypedColumns + from sqlalchemy.orm import DeclarativeBase, mapped_column + from sqlalchemy.orm import MappedColumn, as_typed_table + + + class Base(DeclarativeBase): + pass + + + class A(Base): + __tablename__ = "a" + + id: MappedColumn[int] = mapped_column(primary_key=True) + data: MappedColumn[str] + + + class a_cols(A, TypedColumns): + pass + + + # table_a is annotated as FromClause[a_cols] + table_a = as_typed_table(A, a_cols) + + + class B(Base): + __tablename__ = "b" + __typed_cols__: "b_cols" + + a: Mapped[int] = mapped_column(primary_key=True) + b: Mapped[str] + + + class b_cols(B, TypedColumns): + pass + + + # table_b is a FromClause[b_cols], can call with just B since it + # provides the __typed_cols__ annotation + table_b = as_typed_table(B) + + For proper typing integration :class:`_orm.MappedColumn` should be used + to annotate the single columns, since it's a more specific annotation than + the usual :class:`_orm.Mapped` used for ORM attributes. + + """ + return cls.__table__ + + class DeclarativeBase( # Inspectable is used only by the mypy plugin inspection.Inspectable[InstanceState[Any]], diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 9e30fd84de..f646fbb01b 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -64,6 +64,7 @@ from .. import event from .. import exc from .. import util from ..sql import expression +from ..sql._annotated_cols import TypedColumns from ..sql.base import _NoArg from ..sql.schema import Column from ..sql.schema import Table @@ -276,7 +277,11 @@ class _ORMClassConfigurator: f"Class {cls_!r} already has been instrumented declaratively" ) - if cls_.__dict__.get("__abstract__", False): + # allow subclassing an orm class with typed columns without + # generating an orm class + if cls_.__dict__.get("__abstract__", False) or issubclass( + cls_, TypedColumns + ): return None defer_map = _get_immediate_cls_attr( diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 16de51cef6..cdf4774169 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1020,7 +1020,10 @@ class AliasedInsp( if name: return element.alias(name=name, flat=flat) else: - return coercions.expect( + # see selectable.py->Alias._factory() for similar + # mypy issue. Cannot get the overload to see this + # in mypy (works fine in pyright) + return coercions.expect( # type: ignore[no-any-return] roles.AnonymizedFromClauseRole, element, flat=flat ) else: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 3ac069736e..98c1dc484a 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -9,6 +9,8 @@ from __future__ import annotations +from .sql._annotated_cols import Named as Named +from .sql._annotated_cols import TypedColumns as TypedColumns from .sql.base import SchemaVisitor as SchemaVisitor from .sql.ddl import _CreateDropBase as _CreateDropBase from .sql.ddl import AddConstraint as AddConstraint diff --git a/lib/sqlalchemy/sql/_annotated_cols.py b/lib/sqlalchemy/sql/_annotated_cols.py new file mode 100644 index 0000000000..02c9456092 --- /dev/null +++ b/lib/sqlalchemy/sql/_annotated_cols.py @@ -0,0 +1,397 @@ +# sql/_annotated_cols.py +# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +from typing import Any +from typing import Generic +from typing import Literal +from typing import NoReturn +from typing import overload +from typing import Protocol +from typing import TYPE_CHECKING + +from . import sqltypes +from ._typing import _T +from ._typing import _Ts +from .base import _NoArg +from .base import ReadOnlyColumnCollection +from .. import util +from ..exc import ArgumentError +from ..exc import InvalidRequestError +from ..util import typing as sa_typing +from ..util.langhelpers import dunders_re +from ..util.typing import Never +from ..util.typing import Self +from ..util.typing import TypeVar +from ..util.typing import Unpack + +if TYPE_CHECKING: + from .elements import ColumnClause # noqa (for zimports) + from .elements import KeyedColumnElement # noqa (for zimports) + from .schema import Column + from .type_api import TypeEngine + from ..util.typing import _AnnotationScanType + + +class Named(Generic[_T]): + """A named descriptor that is interpreted by SQLAlchemy in various ways. + + .. seealso:: + + :class:`_schema.TypedColumns` Define table columns using this + descriptor. + + .. versionadded:: 2.1.0b2 + """ + + __slots__ = () + + key: str + if TYPE_CHECKING: + + # NOTE: this overload prevents users from using the a TypedColumns + # class like if it were an orm mapped class + @overload + def __get__(self, instance: None, owner: Any) -> Never: ... + + @overload + def __get__( + self, instance: TypedColumns, owner: Any + ) -> Column[_T]: ... + @overload + def __get__(self, instance: Any, owner: Any) -> Self: ... + + def __get__(self, instance: object | None, owner: Any) -> Any: ... + + +# NOTE: TypedColumns subclasses are ignored by the ORM mapping process +class TypedColumns(ReadOnlyColumnCollection[str, "Column[Any]"]): + """Class that generally represent the typed columns of a :class:`.Table`, + but can be used with most :class:`_sql.FromClause` subclasses with the + :meth:`_sql.FromClause.with_cols` method. + + This is a "typing only" class that is never instantiated at runtime: the + type checker will think that this class is exposed as the ``table.c`` + attribute, but in reality a normal :class:`_schema.ColumnCollection` is + used at runtime. + + Subclasses should just list the columns as class attributes, without + specifying method or other non column members. + + To resolve the columns, a simplified version of the ORM logic is used, + in particular, columns can be declared by: + + * directly instantiating them, to declare constraint, custom SQL types and + additional column options; + * using only a :class:`.Named` or :class:`_schema.Column` type annotation, + where nullability and SQL type will be inferred by the python type + provided. + Type inference is available for a common subset of python types. + * a mix of both, where the instance can be used to declare + constraints and other column options while the annotation will be used + to set the SQL type and nullability if not provided by the instance. + + In all cases the name is inferred from the attribute name, unless + explicitly provided. + + .. note:: + + The generated table will create a copy of any column instance assigned + as attributes of this class, so columns should be accessed only via + the ``table.c`` collection, not using this class directly. + + Example of the inference behavior:: + + from sqlalchemy import Column, Integer, Named, String, TypedColumns + + + class tbl_cols(TypedColumns): + # the name will be set to ``id``, type is inferred as Column[int] + id = Column(Integer, primary_key=True) + + # not null String column is generated + name: Named[str] + + # nullable Double column is generated + weight: Named[float | None] + + # nullable Integer column, with sql name 'user_age' + age: Named[int | None] = Column("user_age") + + # not null column with type String(42) + middle_name: Named[str] = Column(String(42)) + + Mixins and subclasses are also supported:: + + class with_id(TypedColumns): + id = Column(Integer, primary_key=True) + + + class named_cols(TypedColumns): + name: Named[str] + description: Named[str | None] + + + class product_cols(named_cols, with_id): + ean: Named[str] = Column(unique=True) + + + product = Table("product", metadata, product_cols) + + + class office_cols(named_cols, with_id): + address: Named[str] + + + office = Table("office", metadata, office_cols) + + The positional types returned when selecting the table can + be optionally declared by specifying a :attr:`.HasRowPos.__row_pos__` + annotation:: + + from sqlalchemy import select + + + class some_cols(TypedColumns): + id = Column(Integer, primary_key=True) + name: Named[str] + weight: Named[float | None] + + __row_pos__: tuple[int, str, float | None] + + + some_table = Table("st", metadata, some_cols) + + # both will be typed as Select[int, str, float | None] + stmt1 = some_table.select() + stmt2 = select(some_table) + + .. seealso:: + + :class:`.Table` for usage details on how to use this class to + create a table instance. + + :meth:`_sql.FromClause.with_cols` to apply a :class:`.TypedColumns` + to a from clause. + + .. versionadded:: 2.1.0b2 + """ # noqa + + __slots__ = () + + if not TYPE_CHECKING: + + def __new__(cls, *args: Any, **kwargs: Any) -> NoReturn: + raise InvalidRequestError( + "Cannot instantiate a TypedColumns object." + ) + + def __init_subclass__(cls) -> None: + methods = { + name + for name, value in cls.__dict__.items() + if not dunders_re.match(name) and callable(value) + } + if methods: + raise InvalidRequestError( + "TypedColumns subclasses may not define methods. " + f"Found {sorted(methods)}" + ) + + +_KeyColCC_co = TypeVar( + "_KeyColCC_co", + bound=ReadOnlyColumnCollection[str, "KeyedColumnElement[Any]"], + covariant=True, + default=ReadOnlyColumnCollection[str, "KeyedColumnElement[Any]"], +) +_ColClauseCC_co = TypeVar( + "_ColClauseCC_co", + bound=ReadOnlyColumnCollection[str, "ColumnClause[Any]"], + covariant=True, + default=ReadOnlyColumnCollection[str, "ColumnClause[Any]"], +) +_ColCC_co = TypeVar( + "_ColCC_co", + bound=ReadOnlyColumnCollection[str, "Column[Any]"], + covariant=True, + default=ReadOnlyColumnCollection[str, "Column[Any]"], +) + +_TC = TypeVar("_TC", bound=TypedColumns) +_TC_co = TypeVar("_TC_co", bound=TypedColumns, covariant=True) + + +class HasRowPos(Protocol[Unpack[_Ts]]): + """Protocol for a :class:`_schema.TypedColumns` used to indicate the + positional types will be returned when selecting the table. + + .. versionadded:: 2.1.0b2 + """ + + __row_pos__: tuple[Unpack[_Ts]] + """A tuple that represents the types that will be returned when + selecting from the table. + """ + + +@util.preload_module("sqlalchemy.sql.schema") +def _extract_columns_from_class( + table_columns_cls: type[TypedColumns], +) -> list[Column[Any]]: + columns: dict[str, Column[Any]] = {} + + Column = util.preloaded.sql_schema.Column + NULL_UNSPECIFIED = util.preloaded.sql_schema.NULL_UNSPECIFIED + + for base in table_columns_cls.__mro__[::-1]: + if base in TypedColumns.__mro__: + continue + + # _ClassScanAbstractConfig._cls_attr_resolver + cls_annotations = util.get_annotations(base) + cls_vars = vars(base) + items = [ + (n, cls_vars.get(n), cls_annotations.get(n)) + for n in util.merge_lists_w_ordering( + list(cls_vars), list(cls_annotations) + ) + if not dunders_re.match(n) + ] + # -- + for name, obj, annotation in items: + if obj is None: + assert annotation is not None + # no attribute, just annotation + extracted_type = _collect_annotation( + table_columns_cls, name, base.__module__, annotation + ) + if extracted_type is _NoArg.NO_ARG: + raise ArgumentError( + "No type information could be extracted from " + f"annotation {annotation} for attribute " + f"'{base.__name__}.{name}'" + ) + sqltype = _get_sqltype(extracted_type) + if sqltype is None: + raise ArgumentError( + f"Could not find a SQL type for type {extracted_type} " + f"obtained from annotation {annotation} in " + f"attribute '{base.__name__}.{name}'" + ) + columns[name] = Column( + name, + sqltype, + nullable=sa_typing.includes_none(extracted_type), + ) + elif isinstance(obj, Column): + # has attribute attribute + # _DeclarativeMapperConfig._produce_column_copies + # as with orm this case is not supported + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise InvalidRequestError( + f"Column '{base.__name__}.{name}' with foreign " + "key to non-table-bound columns is not supported " + "when using a TypedColumns. If possible use the " + "qualified string name the column" + ) + + col = obj._copy() + # MapptedColumn.declarative_scan + if col.key == col.name and col.key != name: + col.key = name + if col.name is None: + col.name = name + + sqltype = col.type + anno_sqltype = None + nullable: Literal[_NoArg.NO_ARG] | bool = _NoArg.NO_ARG + if annotation is not None: + # there is an annotation, extract the type + extracted_type = _collect_annotation( + table_columns_cls, name, base.__module__, annotation + ) + if extracted_type is not _NoArg.NO_ARG: + anno_sqltype = _get_sqltype(extracted_type) + nullable = sa_typing.includes_none(extracted_type) + + if sqltype._isnull: + if anno_sqltype is None and not col.foreign_keys: + raise ArgumentError( + "Python typing annotation is required for " + f"attribute '{base.__name__}.{name}' when " + "primary argument(s) for Column construct are " + "None or not present" + ) + elif anno_sqltype is not None: + col._set_type(anno_sqltype) + + if ( + nullable is not _NoArg.NO_ARG + and col._user_defined_nullable is NULL_UNSPECIFIED + and not col.primary_key + ): + col.nullable = nullable + columns[name] = col + else: + raise ArgumentError( + f"Unexpected value for attribute '{base.__name__}.{name}'" + f". Expected a Column, not: {type(obj)}" + ) + + # Return columns as a list + return list(columns.values()) + + +@util.preload_module("sqlalchemy.sql.schema") +def _collect_annotation( + cls: type[Any], name: str, module: str, raw_annotation: _AnnotationScanType +) -> _AnnotationScanType | Literal[_NoArg.NO_ARG]: + Column = util.preloaded.sql_schema.Column + + _locals = {"Column": Column, "Named": Named} + # _ClassScanAbstractConfig._collect_annotation & _extract_mapped_subtype + try: + annotation = sa_typing.de_stringify_annotation( + cls, raw_annotation, module, _locals + ) + except Exception as e: + raise ArgumentError( + f"Could not interpret annotation {raw_annotation} for " + f"attribute '{cls.__name__}.{name}'" + ) from e + + if ( + not sa_typing.is_generic(annotation) + and isinstance(annotation, type) + and issubclass(annotation, (Column, Named)) + ): + # no generic information, ignore + return _NoArg.NO_ARG + elif not sa_typing.is_origin_of_cls(annotation, (Column, Named)): + raise ArgumentError( + f"Annotation {raw_annotation} for attribute " + f"'{cls.__name__}.{name}' is not of type Named/Column[...]" + ) + else: + assert len(annotation.__args__) == 1 # Column[int, int] raises + return annotation.__args__[0] # type: ignore[no-any-return] + + +def _get_sqltype(annotation: _AnnotationScanType) -> TypeEngine[Any] | None: + our_type = sa_typing.de_optionalize_union_types(annotation) + # simplified version of registry._resolve_type given no customizable + # type map + sql_type = sqltypes._type_map_get(our_type) # type: ignore[arg-type] + if sql_type is not None and not sql_type._isnull: + return sqltypes.to_instance(sql_type) + else: + return None diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 0c564f97dd..d03b925ed7 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -15,6 +15,8 @@ from typing import Union from . import coercions from . import roles +from ._annotated_cols import _KeyColCC_co +from ._annotated_cols import HasRowPos from ._typing import _ColumnsClauseArgument from ._typing import _no_kw from .elements import ColumnClause @@ -49,6 +51,7 @@ if TYPE_CHECKING: from ._typing import _T8 from ._typing import _T9 from ._typing import _Ts + from ._typing import _Ts2 from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE @@ -58,8 +61,10 @@ if TYPE_CHECKING: def alias( - selectable: FromClause, name: Optional[str] = None, flat: bool = False -) -> NamedFromClause: + selectable: FromClause[_KeyColCC_co], + name: Optional[str] = None, + flat: bool = False, +) -> NamedFromClause[_KeyColCC_co]: """Return a named alias of the given :class:`.FromClause`. For :class:`.Table` and :class:`.Join` objects, the return type is the @@ -496,6 +501,68 @@ def select( # END OVERLOADED FUNCTIONS select +@overload +def select( + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] +) -> Select[Unpack[_Ts]]: ... + + +# NOTE: this seems to currently be interpreted by mypy as not allowed. +# https://peps.python.org/pep-0646/#multiple-type-variable-tuples-not-allowed +# https://github.com/python/mypy/issues/20188 +@overload +def select( + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] + __table2: FromClause[HasRowPos[Unpack[_Ts2]]], # type: ignore[type-var] +) -> Select[Unpack[_Ts], Unpack[_Ts2]]: ... # type: ignore[misc] + + +@overload +def select( + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] + __ent0: _TCCA[_T0], +) -> Select[Unpack[_Ts], _T0]: ... + + +@overload +def select( + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], +) -> Select[Unpack[_Ts], _T0, _T1]: ... + + +@overload +def select( + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], +) -> Select[Unpack[_Ts], _T0, _T1, _T2]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] +) -> Select[_T0, Unpack[_Ts]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] +) -> Select[_T0, _T1, Unpack[_Ts]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __table: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] +) -> Select[_T0, _T1, _T2, Unpack[_Ts]]: ... @overload diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 2848471a43..492601e554 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -81,6 +81,7 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _Ts = TypeVarTuple("_Ts") +_Ts2 = TypeVarTuple("_Ts2") _CE = TypeVar("_CE", bound="ColumnElement[Any]") diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 9652616cb9..56bde75a12 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1933,12 +1933,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): @overload def __getitem__( - self, key: Tuple[Union[str, int], ...] - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... - - @overload - def __getitem__( - self, key: slice + self, key: Union[Tuple[Union[str, int], ...], slice] ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... def __getitem__( diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 58451f320f..4ca0f0bc6f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -39,6 +39,7 @@ from .base import ColumnCollection from .base import ExecutableStatement from .base import Generative from .base import HasMemoized +from .base import ReadOnlyColumnCollection from .base import WriteableColumnCollection from .elements import _type_from_args from .elements import AggregateOrderBy @@ -393,7 +394,9 @@ class FunctionElement( return self.alias(name=name, joins_implicitly=joins_implicitly).column @util.ro_non_memoized_property - def columns(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501 + def columns( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: r"""The set of columns exported by this :class:`.FunctionElement`. This is a placeholder collection that allows the function to be @@ -419,12 +422,12 @@ class FunctionElement( return self.c @util.ro_memoized_property - def c(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501 + def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """synonym for :attr:`.FunctionElement.columns`.""" return WriteableColumnCollection( columns=[(col.key, col) for col in self._all_selected_columns] - ) + ).as_readonly() @property def _all_selected_columns(self) -> Sequence[KeyedColumnElement[Any]]: diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4df6798da6..4a8752cfb6 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -65,6 +65,12 @@ from . import ddl from . import roles from . import type_api from . import visitors +from ._annotated_cols import _ColCC_co +from ._annotated_cols import _extract_columns_from_class +from ._annotated_cols import _TC_co +from ._annotated_cols import Named +from ._annotated_cols import TypedColumns +from ._typing import _T from .base import _DefaultDescriptionTuple from .base import _NoArg from .base import _NoneName @@ -119,7 +125,6 @@ if typing.TYPE_CHECKING: from ..engine.reflection import _ReflectionInfo from ..sql.selectable import FromClause -_T = TypeVar("_T", bound="Any") _SI = TypeVar("_SI", bound="SchemaItem") _TAB = TypeVar("_TAB", bound="Table") @@ -320,12 +325,19 @@ class HasSchemaAttr(SchemaItem): class Table( - DialectKWArgs, HasSchemaAttr, TableClause, inspection.Inspectable["Table"] + DialectKWArgs, + HasSchemaAttr, + TableClause[_ColCC_co], + inspection.Inspectable["Table"], ): r"""Represent a table in a database. e.g.:: + from sqlalchemy import Table, MetaData, Integer, String, Column + + metadata = MetaData() + mytable = Table( "mytable", metadata, @@ -342,10 +354,56 @@ class Table( object - in this way the :class:`_schema.Table` constructor acts as a registry function. + May also be defined as "typed table" by passing a subclass of + :class:`_schema.TypedColumns` as the 3rd argument:: + + from sqlalchemy import TypedColumns, select + + + class user_cols(TypedColumns): + id = Column(Integer, primary_key=True) + name: Column[str] + age: Column[int] + middle_name: Column[str | None] + + # optional, used to infer the select types when selecting the table + __row_pos__: tuple[int, str, int, str | None] + + + user = Table("user", metadata, user_cols) + + # the columns are typed: the statement has type Select[int, str] + stmt = sa.select(user.c.id, user.c.name).where(user.c.age > 30) + + # Inferred as Select[int, str, int, str | None] thanks to __row_pos__ + stmt1 = user.select() + stmt2 = sa.select(user) + + The :attr:`sqlalchemy.sql._annotated_cols.HasRowPos.__row_pos__` + annotation is optional, and it's used to infer the types in a + :class:`_sql.Select` when selecting the complete table. + If a :class:`_schema.TypedColumns` does not define it, + the default ``Select[*tuple[Any]]`` will be inferred. + + An existing :class:`Table` can be casted as "typed table" using + the :meth:`Table.with_cols`:: + + class mytable_cols(TypedColumns): + mytable_id: Column[int] + value: Column[str | None] + + + typed_mytable = mytable.with_cols(mytable_cols) + .. seealso:: - :ref:`metadata_describing` - Introduction to database metadata + :ref:`metadata_describing` Introduction to database metadata + :class:`_schema.TypedColumns` More information about typed column + definition + + .. versionchanged:: 2.1.0b2 - :class:`_schema.Table` is now generic to + support "typed tables" """ __visit_name__ = "table" @@ -358,6 +416,8 @@ class Table( @util.ro_non_memoized_property def foreign_keys(self) -> Set[ForeignKey]: ... + def with_cols(self, type_: type[_TC_co]) -> Table[_TC_co]: ... + _columns: DedupeColumnCollection[Column[Any]] # type: ignore[assignment] _sentinel_column: Optional[Column[Any]] @@ -400,19 +460,6 @@ class Table( """ - if TYPE_CHECKING: - - @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... - - @util.ro_non_memoized_property - def exported_columns( - self, - ) -> ReadOnlyColumnCollection[str, Column[Any]]: ... - - @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... - def _gen_cache_key( self, anon_map: anon_map, bindparams: List[BindParameter[Any]] ) -> Tuple[Any, ...]: @@ -441,11 +488,32 @@ class Table( return object.__new__(cls) try: - name, metadata, args = args[0], args[1], args[2:] - except IndexError: + name, metadata, *other_args = args + except ValueError: raise TypeError( "Table() takes at least two positional-only " - "arguments 'name' and 'metadata'" + "arguments 'name', and 'metadata'" + ) from None + if other_args and isinstance(other_args[0], type): + typed_columns_cls = other_args[0] + if not issubclass(typed_columns_cls, TypedColumns): + raise exc.InvalidRequestError( + "The ``typed_columns_cls`` argument requires a " + "TypedColumns subclass." + ) + elif hasattr(typed_columns_cls, "_sa_class_manager"): + # an orm class subclassed with TypedColumns. Reject it + raise exc.InvalidRequestError( + "To get a typed table from an ORM class, use the " + "`as_typed_table()` function instead." + ) + + extracted_columns = _extract_columns_from_class(typed_columns_cls) + other_args = extracted_columns + other_args[1:] + elif "typed_columns_cls" in kw: + raise TypeError( + "The ``typed_columns_cls`` argument may be passed " + "only positionally" ) schema = kw.get("schema", None) @@ -463,7 +531,7 @@ class Table( must_exist = kw.pop("must_exist", kw.pop("mustexist", False)) key = _get_table_key(name, schema) if key in metadata.tables: - if not keep_existing and not extend_existing and bool(args): + if not keep_existing and not extend_existing and bool(other_args): raise exc.InvalidRequestError( f"Table '{key}' is already defined for this MetaData " "instance. Specify 'extend_existing=True' " @@ -473,7 +541,7 @@ class Table( ) table = metadata.tables[key] if extend_existing: - table._init_existing(*args, **kw) + table._init_existing(*other_args, **kw) return table else: if must_exist: @@ -482,20 +550,45 @@ class Table( table.dispatch.before_parent_attach(table, metadata) metadata._add_table(name, schema, table) try: - table.__init__(name, metadata, *args, _no_init=False, **kw) # type: ignore[misc] # noqa: E501 + table.__init__(name, metadata, *other_args, _no_init=False, **kw) # type: ignore[misc] # noqa: E501 table.dispatch.after_parent_attach(table, metadata) return table except Exception: with util.safe_reraise(): metadata._remove_table(name, schema) + @overload def __init__( - self, + self: Table[_TC_co], name: str, metadata: MetaData, + typed_columns_cls: type[_TC_co], + /, *args: SchemaItem, - schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None, - quote: Optional[bool] = None, + schema: str | Literal[SchemaConst.BLANK_SCHEMA] | None = None, + quote: bool | None = None, + quote_schema: bool | None = None, + keep_existing: bool = False, + extend_existing: bool = False, + implicit_returning: bool = True, + comment: str | None = None, + info: dict[Any, Any] | None = None, + listeners: ( + _typing_Sequence[tuple[str, Callable[..., Any]]] | None + ) = None, + prefixes: _typing_Sequence[str] | None = None, + **kw: Any, + ) -> None: ... + + @overload + def __init__( + self: Table[ReadOnlyColumnCollection[str, Column[Any]]], + name: str, + metadata: MetaData, + /, + *args: SchemaItem, + schema: str | Literal[SchemaConst.BLANK_SCHEMA] | None = None, + quote: bool | None = None, quote_schema: Optional[bool] = None, autoload_with: Optional[Union[Engine, Connection]] = None, autoload_replace: bool = True, @@ -504,12 +597,44 @@ class Table( resolve_fks: bool = True, include_columns: Optional[Collection[str]] = None, implicit_returning: bool = True, - comment: Optional[str] = None, - info: Optional[Dict[Any, Any]] = None, - listeners: Optional[ - _typing_Sequence[Tuple[str, Callable[..., Any]]] - ] = None, - prefixes: Optional[_typing_Sequence[str]] = None, + comment: str | None = None, + info: dict[Any, Any] | None = None, + listeners: ( + _typing_Sequence[tuple[str, Callable[..., Any]]] | None + ) = None, + prefixes: _typing_Sequence[str] | None = None, + _creator_ddl: TableCreateDDL | None = None, + _dropper_ddl: TableDropDDL | None = None, + # used internally in the metadata.reflect() process + _extend_on: Optional[Set[Table]] = None, + # used by __new__ to bypass __init__ + _no_init: bool = True, + # dialect-specific keyword args + **kw: Any, + ) -> None: ... + + def __init__( + self, + name: str, + metadata: MetaData, + /, + *args: Any, + schema: str | Literal[SchemaConst.BLANK_SCHEMA] | None = None, + quote: bool | None = None, + quote_schema: Optional[bool] = None, + autoload_with: Optional[Union[Engine, Connection]] = None, + autoload_replace: bool = True, + keep_existing: bool = False, + extend_existing: bool = False, + resolve_fks: bool = True, + include_columns: Optional[Collection[str]] = None, + implicit_returning: bool = True, + comment: str | None = None, + info: dict[Any, Any] | None = None, + listeners: ( + _typing_Sequence[tuple[str, Callable[..., Any]]] | None + ) = None, + prefixes: _typing_Sequence[str] | None = None, _creator_ddl: TableCreateDDL | None = None, _dropper_ddl: TableDropDDL | None = None, # used internally in the metadata.reflect() process @@ -549,6 +674,12 @@ class Table( may be used to associate this table with a particular :class:`.Connection` or :class:`.Engine`. + :param table_columns_cls: a subclass of :class:`_schema.TypedColumns` + that defines the columns that will be "typed" when accessing + them from the :attr:`_schema.Table.c` attribute. + + .. versionadded:: 2.1.0b2 + :param \*args: Additional positional arguments are used primarily to add the list of :class:`_schema.Column` objects contained within this @@ -556,6 +687,10 @@ class Table( :class:`.SchemaItem` constructs may be added here, including :class:`.PrimaryKeyConstraint`, and :class:`_schema.ForeignKeyConstraint`. + Additional columns may be provided also when using a + :paramref:`_schema.Table.table_columns_cls` class; they will + be appended to the "typed" columns and will appear as untyped + when accessing them via the :attr:`_schema.Table.c` collection. :param autoload_replace: Defaults to ``True``; when using :paramref:`_schema.Table.autoload_with` @@ -813,6 +948,12 @@ class Table( # don't run __init__ from __new__ by default; # __new__ has a specific place that __init__ is called return + if args: + # this is the call done by `__new__` that should have resolved + # TypedColumns to the individual columns + assert not ( + isinstance(args[0], type) and issubclass(args[0], TypedColumns) + ) super().__init__(quoted_name(name, quote)) self.metadata = metadata @@ -1440,7 +1581,7 @@ class Table( ] ] = None, name: Optional[str] = None, - ) -> Table: + ) -> Table[_ColCC_co]: """Return a copy of this :class:`_schema.Table` associated with a different :class:`_schema.MetaData`. @@ -1466,7 +1607,7 @@ class Table( ] ] = None, name: Optional[str] = None, - ) -> Table: + ) -> Table[_ColCC_co]: """Return a copy of this :class:`_schema.Table` associated with a different :class:`_schema.MetaData`. @@ -1566,7 +1707,7 @@ class Table( for col in self.columns: args.append(col._copy(schema=actual_schema, _to_metadata=metadata)) - table = Table( + table: Table[_ColCC_co] = Table( # type: ignore[assignment] name, metadata, schema=actual_schema, @@ -1625,7 +1766,7 @@ class Table( return self._schema_item_copy(table) -class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): +class Column(DialectKWArgs, SchemaItem, ColumnClause[_T], Named[_T]): """Represents a column in a database table.""" __visit_name__ = "column" @@ -2182,15 +2323,14 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): l_args = [__name_pos, __type_pos] + list(args) del args - if l_args: - if isinstance(l_args[0], str): - if name is not None: - raise exc.ArgumentError( - "May not pass name positionally and as a keyword." - ) - name = l_args.pop(0) # type: ignore - elif l_args[0] is None: - l_args.pop(0) + if isinstance(l_args[0], str): + if name is not None: + raise exc.ArgumentError( + "May not pass name positionally and as a keyword." + ) + name = l_args.pop(0) # type: ignore + elif l_args[0] is None: + l_args.pop(0) if l_args: coltype = l_args[0] diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8a0673dc2a..6b10c7fdc7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -49,6 +49,10 @@ from . import roles from . import traversals from . import type_api from . import visitors +from ._annotated_cols import _ColClauseCC_co +from ._annotated_cols import _KeyColCC_co +from ._annotated_cols import _TC_co +from ._annotated_cols import HasRowPos from ._typing import _ColumnsClauseArgument from ._typing import _no_kw from ._typing import _T @@ -615,7 +619,9 @@ class HasHints: return self -class FromClause(roles.AnonymizedFromClauseRole, Selectable): +class FromClause( + roles.AnonymizedFromClauseRole, Generic[_KeyColCC_co], Selectable +): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -643,7 +649,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): def _hide_froms(self) -> Iterable[FromClause]: return () - _is_clone_of: Optional[FromClause] + _is_clone_of: Optional[FromClause[_KeyColCC_co]] _columns: WriteableColumnCollection[Any, Any] @@ -662,6 +668,21 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False + def with_cols(self, type_: Type[_TC_co]) -> FromClause[_TC_co]: + """Cast this :class:`.FromClause` to be generic on a specific a + :class:`_schema.TypedColumns` subclass. + + At runtime returns self unchanged, without performing any validation. + """ + return self # type: ignore + + @overload + def select( + self: FromClause[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] + ) -> Select[Unpack[_Ts]]: ... + @overload + def select(self) -> Select[Unpack[TupleAny]]: ... + def select(self) -> Select[Unpack[TupleAny]]: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -782,7 +803,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): def alias( self, name: Optional[str] = None, flat: bool = False - ) -> NamedFromClause: + ) -> NamedFromClause[_KeyColCC_co]: """Return an alias of this :class:`_expression.FromClause`. E.g.:: @@ -870,9 +891,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): ) @util.ro_non_memoized_property - def exported_columns( - self, - ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + def exported_columns(self) -> _KeyColCC_co: """A :class:`_expression.ColumnCollection` that represents the "exported" columns of this :class:`_expression.FromClause`. @@ -894,9 +913,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_non_memoized_property - def columns( - self, - ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + def columns(self) -> _KeyColCC_co: """A named-based collection of :class:`_expression.ColumnElement` objects maintained by this :class:`_expression.FromClause`. @@ -912,7 +929,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.c @util.ro_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + def c(self) -> _KeyColCC_co: """ A synonym for :attr:`.FromClause.columns` @@ -921,7 +938,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): """ if "_columns" not in self.__dict__: self._setup_collections() - return self._columns.as_readonly() + return self._columns.as_readonly() # type: ignore[return-value] def _setup_collections(self) -> None: with util.mini_gil: @@ -1072,7 +1089,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): ) -> Union[FromGrouping, Self]: ... -class NamedFromClause(FromClause): +class NamedFromClause(FromClause[_KeyColCC_co]): """A :class:`.FromClause` that has a name. Examples include tables, subqueries, CTEs, aliased tables. @@ -1115,6 +1132,12 @@ class NamedFromClause(FromClause): """ return TableValuedColumn(self, type_api.TABLEVALUE) + if TYPE_CHECKING: + + def with_cols( + self, type_: type[_TC_co] + ) -> NamedFromClause[_TC_co]: ... + class SelectLabelStyle(Enum): """Label style constants that may be passed to @@ -1244,7 +1267,7 @@ class SelectLabelStyle(Enum): LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY -class Join(roles.DMLTableRole, FromClause): +class Join(roles.DMLTableRole, FromClause[_KeyColCC_co]): """Represent a ``JOIN`` construct between two :class:`_expression.FromClause` elements. @@ -1604,6 +1627,13 @@ class Join(roles.DMLTableRole, FromClause): "join explicitly." % (a.description, b.description) ) + @overload + def select( + self: Join[HasRowPos[Unpack[_Ts]]], # type: ignore[type-var] + ) -> Select[Unpack[_Ts]]: ... + @overload + def select(self) -> Select[Unpack[TupleAny]]: ... + def select(self) -> Select[Unpack[TupleAny]]: r"""Create a :class:`_expression.Select` from this :class:`_expression.Join`. @@ -1677,6 +1707,10 @@ class Join(roles.DMLTableRole, FromClause): self_list: List[FromClause] = [self] return self_list + self.left._from_objects + self.right._from_objects + if TYPE_CHECKING: + + def with_cols(self, type_: type[_TC_co]) -> Join[_TC_co]: ... + class NoInit: def __init__(self, *arg: Any, **kw: Any): @@ -1707,7 +1741,7 @@ class LateralFromClause(NamedFromClause): # -> TableSample -> only for FromClause -class AliasedReturnsRows(NoInit, NamedFromClause): +class AliasedReturnsRows(NoInit, NamedFromClause[_KeyColCC_co]): """Base class of aliases against tables, subqueries, and other selectables.""" @@ -1808,8 +1842,8 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return [self] -class FromClauseAlias(AliasedReturnsRows): - element: FromClause +class FromClauseAlias(AliasedReturnsRows[_KeyColCC_co]): + element: FromClause[_KeyColCC_co] @util.ro_non_memoized_property def description(self) -> str: @@ -1820,7 +1854,7 @@ class FromClauseAlias(AliasedReturnsRows): return name -class Alias(roles.DMLTableRole, FromClauseAlias): +class Alias(roles.DMLTableRole, FromClauseAlias[_KeyColCC_co]): """Represents an table or selectable alias (AS). Represents an alias, as typically applied to any table or @@ -1842,16 +1876,18 @@ class Alias(roles.DMLTableRole, FromClauseAlias): inherit_cache = True - element: FromClause + element: FromClause[_KeyColCC_co] @classmethod def _factory( cls, - selectable: FromClause, + selectable: FromClause[_KeyColCC_co], name: Optional[str] = None, flat: bool = False, - ) -> NamedFromClause: - return coercions.expect(roles.FromClauseRole, selectable).alias( + ) -> NamedFromClause[_KeyColCC_co]: + # mypy refuses to see the overload that has this returning + # NamedFromClause[Any]. Pylance sees it just fine. + return coercions.expect(roles.FromClauseRole, selectable).alias( # type: ignore[no-any-return] # noqa: E501 name=name, flat=flat ) @@ -2137,7 +2173,7 @@ class CTE( Generative, HasPrefixes, HasSuffixes, - AliasedReturnsRows, + AliasedReturnsRows[_KeyColCC_co], ): """Represent a Common Table Expression. @@ -2197,8 +2233,8 @@ class CTE( name: Optional[str] = None, recursive: bool = False, nesting: bool = False, - _cte_alias: Optional[CTE] = None, - _restates: Optional[CTE] = None, + _cte_alias: Optional[CTE[_KeyColCC_co]] = None, + _restates: Optional[CTE[_KeyColCC_co]] = None, _prefixes: Optional[Tuple[()]] = None, _suffixes: Optional[Tuple[()]] = None, ) -> None: @@ -2234,7 +2270,9 @@ class CTE( foreign_keys=foreign_keys, ) - def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE: + def alias( + self, name: Optional[str] = None, flat: bool = False + ) -> CTE[_KeyColCC_co]: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -2258,7 +2296,9 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, *other: _SelectStatementForCompoundArgument[Any]) -> CTE: + def union( + self, *other: _SelectStatementForCompoundArgument[Any] + ) -> CTE[_KeyColCC_co]: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2289,7 +2329,7 @@ class CTE( def union_all( self, *other: _SelectStatementForCompoundArgument[Any] - ) -> CTE: + ) -> CTE[_KeyColCC_co]: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. @@ -2319,7 +2359,7 @@ class CTE( _suffixes=self._suffixes, ) - def _get_reference_cte(self) -> CTE: + def _get_reference_cte(self) -> CTE[_KeyColCC_co]: """ A recursive CTE is updated to attach the recursive part. Updated CTEs should still refer to the original CTE. @@ -2327,6 +2367,10 @@ class CTE( """ return self._restates if self._restates is not None else self + if TYPE_CHECKING: + + def with_cols(self, type_: type[_TC_co]) -> CTE[_TC_co]: ... + class _CTEOpts(NamedTuple): nesting: bool @@ -2965,7 +3009,7 @@ class HasCTE(roles.HasCTERole, SelectsRows): ) -class Subquery(AliasedReturnsRows): +class Subquery(AliasedReturnsRows[_KeyColCC_co]): """Represent a subquery of a SELECT. A :class:`.Subquery` is created by invoking the @@ -3026,26 +3070,24 @@ class Subquery(AliasedReturnsRows): return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery() -class FromGrouping(GroupedElement, FromClause): +class FromGrouping(GroupedElement, FromClause[_KeyColCC_co]): """Represent a grouping of a FROM clause""" _traverse_internals: _TraverseInternalsType = [ ("element", InternalTraversal.dp_clauseelement) ] - element: FromClause + element: FromClause[_KeyColCC_co] - def __init__(self, element: FromClause): + def __init__(self, element: FromClause[_KeyColCC_co]): self.element = coercions.expect(roles.FromClauseRole, element) @util.ro_non_memoized_property - def columns( - self, - ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + def columns(self) -> _KeyColCC_co: return self.element.columns @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + def c(self) -> _KeyColCC_co: return self.element.columns @property @@ -3061,11 +3103,15 @@ class FromGrouping(GroupedElement, FromClause): def alias( self, name: Optional[str] = None, flat: bool = False - ) -> NamedFromGrouping: + ) -> NamedFromGrouping[_KeyColCC_co]: return NamedFromGrouping(self.element.alias(name=name, flat=flat)) - def _anonymous_fromclause(self, **kw: Any) -> FromGrouping: - return FromGrouping(self.element._anonymous_fromclause(**kw)) + def _anonymous_fromclause( + self, *, name: Optional[str] = None, flat: bool = False + ) -> FromGrouping: + return FromGrouping( + self.element._anonymous_fromclause(name=name, flat=flat) + ) @util.ro_non_memoized_property def _hide_froms(self) -> Iterable[FromClause]: @@ -3075,10 +3121,10 @@ class FromGrouping(GroupedElement, FromClause): def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def __getstate__(self) -> Dict[str, FromClause]: + def __getstate__(self) -> Dict[str, FromClause[_KeyColCC_co]]: return {"element": self.element} - def __setstate__(self, state: Dict[str, FromClause]) -> None: + def __setstate__(self, state: Dict[str, FromClause[_KeyColCC_co]]) -> None: self.element = state["element"] if TYPE_CHECKING: @@ -3088,7 +3134,9 @@ class FromGrouping(GroupedElement, FromClause): ) -> Self: ... -class NamedFromGrouping(FromGrouping, NamedFromClause): +class NamedFromGrouping( + FromGrouping[_KeyColCC_co], NamedFromClause[_KeyColCC_co] +): """represent a grouping of a named FROM clause .. versionadded:: 2.0 @@ -3104,7 +3152,9 @@ class NamedFromGrouping(FromGrouping, NamedFromClause): ) -> Self: ... -class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): +class TableClause( + roles.DMLTableRole, Immutable, NamedFromClause[_ColClauseCC_co] +): """Represents a minimal "table" construct. This is a lightweight table object that has only a name, a @@ -3183,13 +3233,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): if TYPE_CHECKING: - @util.ro_non_memoized_property - def columns( - self, - ) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... - - @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... + def with_cols(self, type_: type[_TC_co]) -> TableClause[_TC_co]: ... def __str__(self) -> str: if self.schema is not None: diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index affa6c4fa0..50a22026f2 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -2036,7 +2036,7 @@ def attrsetter(attrname): return env["set"] -_dunders = re.compile("^__.+__$") +dunders_re = re.compile("^__.+__$") class TypingOnly: @@ -2050,7 +2050,7 @@ class TypingOnly: def __init_subclass__(cls, **kw: Any) -> None: if TypingOnly in cls.__bases__: remaining = { - name for name in cls.__dict__ if not _dunders.match(name) + name for name in cls.__dict__ if not dunders_re.match(name) } if remaining: raise AssertionError( diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 6a0d2ed85c..01bf0a7b3a 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -35,7 +35,6 @@ from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeGuard -from typing import TypeVar from typing import Union import typing_extensions @@ -53,6 +52,7 @@ if True: # zimports removes the tailing comments from typing_extensions import Unpack as Unpack # 3.11 from typing_extensions import Never as Never # 3.11 from typing_extensions import LiteralString as LiteralString # 3.11 + from typing_extensions import TypeVar as TypeVar # 3.13 for default _T = TypeVar("_T", bound=Any) diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 3d921ed5e9..d79417a5c6 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -19,6 +19,7 @@ from sqlalchemy import UniqueConstraint from sqlalchemy import Uuid from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import as_declarative +from sqlalchemy.orm import as_typed_table from sqlalchemy.orm import backref from sqlalchemy.orm import class_mapper from sqlalchemy.orm import clear_mappers @@ -44,6 +45,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym from sqlalchemy.orm import synonym_for +from sqlalchemy.orm.base import opt_manager_of_class from sqlalchemy.orm.decl_api import add_mapped_attribute from sqlalchemy.orm.decl_api import DeclarativeBaseNoMeta from sqlalchemy.orm.decl_api import DeclarativeMeta @@ -51,6 +53,7 @@ from sqlalchemy.orm.decl_base import _DeferredDeclarativeConfig from sqlalchemy.orm.events import InstrumentationEvents from sqlalchemy.orm.events import MapperEvents from sqlalchemy.schema import PrimaryKeyConstraint +from sqlalchemy.schema import TypedColumns from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import assertions @@ -3458,3 +3461,183 @@ class NamedAttrOrderingTest(fixtures.TestBase): "for argument 'local_table'; got", ): registry().map_imperatively(ImpModel, DecModel) + + +class TypedColumnInteropTest(fixtures.TestBase): + @testing.variation( + "mapping_style", + [ + "decl_base_fn", + "decl_base_base", + "decl_base_no_meta", + "map_declaratively", + "decorator", + "mapped_as_dataclass", + ], + ) + def test_define_typed_columns(self, mapping_style): + if mapping_style.decl_base_fn: + Base = declarative_base() + + class DecModel(Base): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + r = Base.registry + elif mapping_style.decl_base_base: + + class Base(DeclarativeBase): + pass + + class DecModel(Base): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + r = Base.registry + elif mapping_style.decl_base_no_meta: + + class Base(DeclarativeBaseNoMeta): + pass + + class DecModel(Base): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + r = Base.registry + elif mapping_style.decorator: + r = registry() + + @r.mapped + class DecModel: + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + elif mapping_style.map_declaratively: + r = registry() + + class DecModel: + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + r.map_declaratively(DecModel) + elif mapping_style.decorator: + r = registry() + + @r.mapped + class DecModel: + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + elif mapping_style.mapped_as_dataclass: + r = registry() + + @r.mapped_as_dataclass + class DecModel: + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + else: + assert False + + class the_cols(DecModel, TypedColumns): + pass + + class the_cols2(TypedColumns, DecModel): + pass + + for cls in (the_cols, the_cols2): + assertions.not_in(cls, r._class_registry) + is_(opt_manager_of_class(cls), None) + is_(cls.__mapper__.class_, DecModel) + + def test_define_table_orm(self): + + class Base(DeclarativeBase): + pass + + class DecModel(Base): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + with expect_raises_message( + exc.InvalidRequestError, + "The ``typed_columns_cls`` argument requires a " + "TypedColumns subclass", + ): + Table("bar", Base.metadata, DecModel) + + @testing.combinations("foo", "bar", argnames="tablename") + def test_define_new_table_with_cols(self, tablename): + + class Base(DeclarativeBase): + pass + + class DecModel(Base): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + class the_cols(DecModel, TypedColumns): + pass + + with expect_raises_message( + exc.InvalidRequestError, + "To get a typed table from an ORM class, use the " + r"`as_typed_table\(\)` function instead", + ): + Table(tablename, Base.metadata, the_cols) + + @testing.variation("assign", [True, False]) + def test_define___typed_cols__(self, assign): + + class Base(DeclarativeBase): + pass + + class DecModel(Base): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + __typed_cols__: "cols" + + class cols(DecModel, TypedColumns): + pass + + if assign: + DecModel.__typed_cols__ = cols + + assertions.not_in("__typed_cols__", DecModel.__mapper__.attrs) + + def test_as_typed_table(self): + class Base(DeclarativeBase): + pass + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + class a_cols(A, TypedColumns): + pass + + t = as_typed_table(A, a_cols) + is_(t, A.__table__) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + __typed_cols__: "b_cols" + + class b_cols(A, TypedColumns): + pass + + t2 = as_typed_table(B) + is_(t2, B.__table__) diff --git a/test/sql/test_typed_froms.py b/test/sql/test_typed_froms.py new file mode 100644 index 0000000000..9a2e068aa4 --- /dev/null +++ b/test/sql/test_typed_froms.py @@ -0,0 +1,758 @@ +from typing import Annotated + +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Double +from sqlalchemy import Integer +from sqlalchemy import Named +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import TypedColumns +from sqlalchemy.exc import ArgumentError +from sqlalchemy.exc import DuplicateColumnError +from sqlalchemy.exc import InvalidRequestError +from sqlalchemy.testing import combinations +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import in_ +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_instance_of +from sqlalchemy.testing import is_not +from sqlalchemy.testing import not_in + + +class TypedTableTest(fixtures.TestBase): + """Test suite for typed table and TypedColumns classes.""" + + def test_table_creation(self, metadata): + """Test that typed_table creates an actual Table in metadata.""" + + class name_does_not_matter(TypedColumns): + id: Named[int] + name: Named[str] + + test_table = Table("test_table", metadata, name_does_not_matter) + + is_instance_of(test_table, sa.Table) + is_(type(test_table), sa.Table) + eq_(test_table.name, "test_table") + in_("test_table", metadata.tables) + is_(metadata.tables["test_table"], test_table) + + def test_empty(self, metadata): + """Test that typed_table creates an actual Table in metadata.""" + + class empty_cols(TypedColumns): + pass + + empty = Table("test_table", metadata, empty_cols) + + is_instance_of(empty, sa.Table) + eq_(len(empty.c), 0) + + def test_simple_columns_with_objects(self, metadata): + """Test table with explicit Column objects.""" + + class users_cols(TypedColumns): + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + users = Table("users", metadata, users_cols, schema="my_schema") + + eq_(users.schema, "my_schema") + eq_(len(users.c), 2) + in_("id", users.c) + in_("name", users.c) + is_(users.c.id.primary_key, True) + is_instance_of(users.c.name.type, String) + eq_(users.c.name.type.length, 50) + + def test_columns_are_copied(self, metadata): + """Test table with explicit Column objects.""" + + class usersCols(TypedColumns): + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + user = Table("users", metadata, usersCols, schema="my_schema") + + is_not(user.c.id, usersCols.id) + is_not(user.c.name, usersCols.name) + is_(usersCols.id.table, None) + is_(usersCols.name.table, None) + + def test_columns_with_annotations_only(self, metadata): + """Test table with type annotations only (no Column objects).""" + + class products_cols(TypedColumns): + id: Column[int] + name: Column[str] + weight: Column[float] + + products = Table("products", metadata, products_cols) + + eq_(len(products.c), 3) + in_("id", products.c) + in_("name", products.c) + in_("weight", products.c) + is_instance_of(products.c.id.type, Integer) + is_instance_of(products.c.name.type, String) + is_instance_of(products.c.weight.type, Double) + + def test_columns_with_annotations_only_named(self, metadata): + """Test table with type annotations only using Named.""" + + class products_cols(TypedColumns): + id: Named[int] + name: Named[str] + weight: Named[float] + + products = Table("products", metadata, products_cols) + + eq_(len(products.c), 3) + in_("id", products.c) + in_("name", products.c) + in_("weight", products.c) + is_instance_of(products.c.id.type, Integer) + is_instance_of(products.c.name.type, String) + is_instance_of(products.c.weight.type, Double) + + def test_mixed_columns_and_annotations(self, metadata): + """Test table with mix of Column objects and annotations.""" + + class items_cols(TypedColumns): + id = Column(Integer, primary_key=True) + name: Column[str] + price: Column[float] + + items = Table("items", metadata, items_cols) + + eq_(len(items.c), 3) + is_(items.c.id.primary_key, True) + in_("name", items.c) + in_("price", items.c) + is_instance_of(items.c.id.type, Integer) + is_instance_of(items.c.name.type, String) + is_instance_of(items.c.price.type, Double) + + def test_annotation_completion(self, metadata): + """Complete column information from annotation.""" + + class items_cols(TypedColumns): + id: Column[int | None] = Column(primary_key=True) + name: Column[str] = Column(String(100)) + price: Column[float] = Column(nullable=True) + + items = Table("items", metadata, items_cols) + + eq_(len(items.c), 3) + is_(items.c.id.primary_key, True) + is_(items.c.id.nullable, False) + in_("name", items.c) + in_("price", items.c) + is_instance_of(items.c.id.type, Integer) + is_instance_of(items.c.name.type, String) + eq_(items.c.name.type.length, 100) + is_instance_of(items.c.price.type, Double) + is_(items.c.price.nullable, True) + + def test_type_from_anno_ignored_when_provided(self, metadata): + """Complete column information from annotation.""" + + class items_cols(TypedColumns): + id: Column[int | None] = Column(primary_key=True) + name: Column[str] = Column(Double) + price: Column[float] = Column(String) + + items = Table("items", metadata, items_cols) + + is_instance_of(items.c.name.type, Double) + is_instance_of(items.c.price.type, String) + + def test_nullable_from_annotation(self, metadata): + """Test nullable inference from Optional annotation.""" + + class records_cols(TypedColumns): + id: Column[int] + description: Column[str | None] + + records = Table("records", metadata, records_cols) + + is_instance_of(records.c.id.type, Integer) + is_(records.c.id.nullable, False) + is_instance_of(records.c.description.type, String) + is_(records.c.description.nullable, True) + + def test_nullable_from_annotation_ignored_when_set(self, metadata): + """Test nullable inference is ignored if nullable is set""" + + class records_cols(TypedColumns): + id: Column[int] = Column(nullable=True) + description: Column[str | None] = Column(nullable=False) + + records = Table("records", metadata, records_cols) + + is_instance_of(records.c.id.type, Integer) + is_(records.c.id.nullable, True) + is_instance_of(records.c.description.type, String) + is_(records.c.description.nullable, False) + + @combinations(True, False, argnames="define_cols") + def test_use_same_typedcols_multiple_times(self, metadata, define_cols): + if define_cols: + + class cols(TypedColumns): + id = Column(Integer) + name = Column(String) + + else: + + class cols(TypedColumns): + id: Column[int] + name: Column[str] + + t1 = Table("t1", metadata, cols) + t2 = Table("t2", metadata, cols) + is_not(t1.c.id, t2.c.id) + is_not(t1.c.name, t2.c.name) + + def test_bad_anno_with_type_provided(self, metadata): + """Test error when no type info is found.""" + + class MyType: + pass + + class ThisIsFine_cols(TypedColumns): + id: Column[MyType] = Column(Double) + name: Column = Column(String) + + ThisIsFine = Table("tbl", metadata, ThisIsFine_cols) + + is_instance_of(ThisIsFine.c.id.type, Double) + is_instance_of(ThisIsFine.c.name.type, String) + + def test_inheritance_from_typed_columns(self, metadata): + """Test column inheritance from parent TypedColumns.""" + + class base_columns(TypedColumns): + id: Column[int] + + class derived_cols(base_columns): + name: Column[str] + + derived = Table("derived", metadata, derived_cols) + + eq_(len(derived.c), 2) + in_("id", derived.c) + in_("name", derived.c) + eq_(derived.c.keys(), ["id", "name"]) # check order + + def test_many_mixin(self, metadata): + class with_name(TypedColumns): + name: Column[str] + + class with_age(TypedColumns): + age: Column[int] + + class person_cols(with_age, with_name): + id = Column(Integer, primary_key=True) + + person = Table("person", metadata, person_cols) + + eq_(person.c.keys(), ["name", "age", "id"]) + + def test_shared_base_columns_different_tables(self, metadata): + """Test that a base TypedColumns can be used in multiple tables + with different instances.""" + + class base_columns(TypedColumns): + id: Column[int] + + class table1_cols(base_columns): + name: Column[str] + + table1 = Table("table1", metadata, table1_cols) + + class table2_cols(base_columns): + name: Column[str] + + table2 = Table("table2", metadata, table2_cols) + + in_("id", table1.c) + in_("id", table2.c) + is_not(table1.c.id, table2.c.id) + + def test_shared_column_with_pk_different_tables(self, metadata): + """Test that base column instances with pk are independent in + different tables.""" + + class base_columns(TypedColumns): + id = Column(Integer, primary_key=True) + + class table1_cols(base_columns): + name: Column[str] + + table1 = Table("table1", metadata, table1_cols) + + class table2_cols(base_columns): + other: Column[str] + + table2 = Table("table2", metadata, table2_cols) + + is_(table1.c.id.primary_key, True) + is_(table2.c.id.primary_key, True) + is_not(table1.c.id, table2.c.id) + is_not(table1.c.id, base_columns.id) + is_not(table2.c.id, base_columns.id) + in_("name", table1.c) + not_in("name", table2.c) + in_("other", table2.c) + not_in("other", table1.c) + + def test_override_column(self, metadata): + """Test that a base TypedColumns can be used in multiple tables + with different instances.""" + + class base_columns(TypedColumns): + id: Column[int] + name: Column[str] + theta: Column[float] + + class mid_columns(base_columns): + name: Column[str | None] # override to make nullable + theta: Column[float] = Column(sa.Numeric(asdecimal=False)) + + class table1_cols(mid_columns): + id: Column[int] = Column(sa.BigInteger) + + table1 = Table("table1", metadata, table1_cols) + + eq_(len(table1.c), 3) + is_instance_of(table1.c.id.type, sa.BigInteger) + is_instance_of(table1.c.name.type, String) + is_(table1.c.name.nullable, True) + is_instance_of(table1.c.theta.type, sa.Numeric) + is_(mid_columns.theta.table, None) + + def test_column_name_and_key_set(self, metadata): + """Test that column name and key are properly set.""" + + class t_cols(TypedColumns): + user_id: Column[int] + + t = Table("t", metadata, t_cols) + + col = t.c.user_id + eq_(col.name, "user_id") + eq_(col.key, "user_id") + + def test_provide_different_col_name(self, metadata): + class t_cols(TypedColumns): + user_id: Column[int] = Column("uid") + + t = Table("t", metadata, t_cols) + in_("user_id", t.c) + not_in("uid", t.c) + col = t.c.user_id + eq_(col.name, "uid") + eq_(col.key, "user_id") + not_in("user_id", str(t.select())) + + def test_provide_different_key(self, metadata): + # this doesn't make a lot of sense, but it's consistent with the orm + class t_cols(TypedColumns): + user_id: Column[int] = Column(key="uid") + + t = Table("t", metadata, t_cols) + in_("uid", t.c) + not_in("user_id", t.c) + col = t.c.uid + eq_(col.name, "user_id") + eq_(col.key, "uid") + not_in("uid", str(t.select())) + + def test_add_more_columns(self, metadata): + + class records_cols(TypedColumns): + id: Column[int] + description: Column[str | None] + + records = Table( + "records", + metadata, + records_cols, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + + eq_(records.c.keys(), ["id", "description", "x", "y", "z"]) + + def test_add_constraints(self, metadata): + + class records_cols(TypedColumns): + id: Column[int] + description: Column[str | None] + + records = Table( + "records", + metadata, + records_cols, + Column("x", Integer), + Column("y", Integer), + sa.Index("foo", "id", "y"), + sa.UniqueConstraint("description"), + ) + + eq_(records.c.keys(), ["id", "description", "x", "y"]) + is_(len(records.indexes), 1) + eq_(list(records.indexes)[0].columns, [records.c.id, records.c.y]) + is_(len(records.constraints), 2) # including pk + (uq,) = [ + c + for c in records.constraints + if isinstance(c, sa.UniqueConstraint) + ] + eq_(uq.columns, [records.c.description]) + + def test_init_no_col_no_typed_cols(self, metadata): + tt = Table("a", metadata) + eq_(len(tt.c), 0) + + def test_invalid_non_typed_columns(self, metadata): + """Test that rejects non-TypedColumns subclasses.""" + with expect_raises_message( + InvalidRequestError, "requires a TypedColumns subclass" + ): + + class not_typed_columns: + id = Column(Integer) + + Table("bad", metadata, not_typed_columns) + + with expect_raises_message( + ArgumentError, "'SchemaItem' object, such as a 'Column'" + ): + Table("bad", metadata, 123) # not a class at all + + def test_no_kw_args(self, metadata): + """Test that rejects TypedColumns as kw args.""" + with expect_raises_message( + TypeError, + "The ``typed_columns_cls`` argument may be passed " + "only positionally", + ): + + class not_typed_columns(TypedColumns): + id = Column(Integer) + + Table("bad", metadata, typed_columns_cls=not_typed_columns) + + with expect_raises_message( + ArgumentError, "'SchemaItem' object, such as a 'Column'" + ): + Table("bad", metadata, 123) # not a class at all + + def test_invalid_method_definition(self, metadata): + """Test that TypedColumns rejects method definitions.""" + with expect_raises_message( + InvalidRequestError, "may not define methods" + ): + + class invalid(TypedColumns): + id: Column[int] + + def some_method(self): + pass + + def test_cannot_interpret_annotation(self, metadata): + with expect_raises_message( + ArgumentError, + "Could not interpret annotation this it not valid for " + "attribute 'not_typed_columns.id'", + ): + + class not_typed_columns(TypedColumns): + id: "this it not valid" # noqa + + Table("bad", metadata, not_typed_columns) + + def test_invalid_annotation_type(self, metadata): + """Test error when annotation is not Column[...].""" + + with expect_raises_message( + ArgumentError, + "Annotation for attribute 'bad_anno.id' is not " + "of type Named/Column", + ): + + class bad_anno(TypedColumns): + id: int # Missing Column[...] + + Table("bad_anno", metadata, bad_anno) + + def test_missing_generic_in_column(self, metadata): + with expect_raises_message( + ArgumentError, + "No type information could be extracted from annotation " + " for attribute " + "'bad_anno.id'", + ): + + class bad_anno(TypedColumns): + id: Column # missing generic + + Table("bad_anno", metadata, bad_anno) + + def test_missing_generic_in_named(self, metadata): + with expect_raises_message( + ArgumentError, + "No type information could be extracted from annotation " + " for attribute " + "'bad_anno.id'", + ): + + class bad_anno(TypedColumns): + id: Named # missing generic + + Table("bad_anno", metadata, bad_anno) + + def test_no_pep593(self, metadata): + """Test nullable inference is ignored if nullable is set""" + + class records_cols(TypedColumns): + id: Column[Annotated[int, "x"]] + description: Column[str | None] + + with expect_raises_message( + ArgumentError, + "Could not find a SQL type for type typing.Annotated.+" + " obtained from annotation .+ in attribute 'records_cols.id'", + ): + + Table("records", metadata, records_cols) + + def test_no_pep593_columns(self, metadata): + """Test nullable inference is ignored if nullable is set""" + + class records_cols(TypedColumns): + id: Column[Annotated[int, Column(Integer, primary_key=True)]] + description: Column[str | None] + + with expect_raises_message( + ArgumentError, + "Could not find a SQL type for type typing.Annotated.+" + " obtained from annotation .+ in attribute 'records_cols.id'", + ): + + Table("records", metadata, records_cols) + + def test_unknown_type(self, metadata): + """Test error when no type info is found.""" + + class MyType: + pass + + with expect_raises_message( + ArgumentError, + "Could not find a SQL type for type .*MyType.+ obtained from " + "annotation .* in attribute 'bad_anno.id'", + ): + + class bad_anno(TypedColumns): + id: Column[MyType] + + Table("bad_anno", metadata, bad_anno) + + def test_invalid_annotation_type_provided_column(self, metadata): + """Test error when annotation is not Column[...].""" + + with expect_raises_message( + ArgumentError, + "Annotation for attribute 'bad_anno.id' is not " + "of type Named/Column", + ): + + class bad_anno(TypedColumns): + id: int = Column(Integer) + + Table("bad_anno", metadata, bad_anno) + + def test_missing_generic_in_column_provided_col(self, metadata): + with expect_raises_message( + ArgumentError, + "Python typing annotation is required for attribute " + r"'bad_anno.id' when primary argument\(s\) for Column construct " + "are None or not present", + ): + + class bad_anno(TypedColumns): + id: Column = Column(nullable=False) + + Table("bad_anno", metadata, bad_anno) + + def test_unknown_type_provided_col(self, metadata): + """Test error when no type info is found.""" + + class MyType: + pass + + with expect_raises_message( + ArgumentError, + "Python typing annotation is required for attribute " + r"'bad_anno.id' when primary argument\(s\) for Column construct " + "are None or not present", + ): + + class bad_anno(TypedColumns): + id: Column[MyType] = Column(nullable=False) + + Table("bad_anno", metadata, bad_anno) + + def test_invalid_attribute_value(self, metadata): + """Test error when attribute is neither Column nor annotation.""" + with expect_raises_message(ArgumentError, "Expected a Column"): + + class bad_attr(TypedColumns): + id = 42 # Invalid: not a Column + + Table("bad_attr", metadata, bad_attr) + + def test_cannot_instantiate_typed_columns(self): + """Test that TypedColumns cannot be directly instantiated.""" + + class TestTC(TypedColumns): + pass + + with expect_raises_message(InvalidRequestError, "Cannot instantiate"): + TestTC() + + def test_mix_column_duplicate(self, metadata): + + with expect_raises_message( + DuplicateColumnError, + "A column with name 'y' is already present in table 'records'", + ): + + class records(TypedColumns): + id: Column[int] + y: Column[str | None] + + Table( + "records", + metadata, + records, + Column("x", Integer), + Column("y", Integer), + Column("z", Integer), + ) + + def test_simple_fk(self, metadata): + t = sa.Table("t1", metadata, Column("id", Integer)) + + class t2_cols(TypedColumns): + id: Column[int] + t1_id = Column(sa.ForeignKey("t1.id")) + + t2 = Table("t2", metadata, t2_cols) + + is_instance_of(t2.c.t1_id.type, Integer) + eq_(len(t2.c.t1_id.foreign_keys), 1) + is_(list(t2.c.t1_id.foreign_keys)[0].column, t.c.id) + + def test_simple_fk_many_times(self, metadata): + t = sa.Table("t1", metadata, Column("id", Integer)) + + class cols(TypedColumns): + id: Column[int] + t1_id = Column(sa.ForeignKey("t1.id")) + + t2 = Table("t2", metadata, cols) + t3 = Table("t3", metadata, cols) + t4 = Table("t4", metadata, cols) + + cc = set() + fk = set() + for tx in (t2, t3, t4): + is_not(tx.c.t1_id, cols.t1_id) + eq_(tx.c.t1_id.foreign_keys & cols.t1_id.foreign_keys, set()) + eq_(len(tx.c.t1_id.foreign_keys), 1) + is_(list(tx.c.t1_id.foreign_keys)[0].column, t.c.id) + cc.add(tx.c.t1_id) + fk.update(tx.c.t1_id.foreign_keys) + eq_(len(cc), 3) + eq_(len(fk), 3) + + def test_fk_mixin(self, metadata): + t = sa.Table("t1", metadata, Column("id", Integer)) + + class tid(TypedColumns): + t1_id = Column(sa.ForeignKey("t1.id")) + + class a_cols(tid): + id: Column[int] + + a = Table("a", metadata, a_cols) + + class b_cols(tid): + b: Column[int] + + b = Table("b", metadata, b_cols) + + for tx in (a, b): + eq_(len(tx.c.t1_id.foreign_keys), 1) + is_(list(tx.c.t1_id.foreign_keys)[0].column, t.c.id) + + def test_fk_non_tbl_bound(self, metadata): + + with expect_raises_message( + InvalidRequestError, + "Column 'a.t1_id' with foreign " + "key to non-table-bound columns is not supported " + "when using a TypedColumns. If possible use the " + "qualified string name the column", + ): + + class a(TypedColumns): + id = Column(Integer) + t1_id = Column(sa.ForeignKey(id)) + a: Column[int] + + Table("a", metadata, a) + + def test_fk_mixin_non_tbl_bound(self, metadata): + class tid(TypedColumns): + id = Column(Integer) + t1_id = Column(sa.ForeignKey(id)) + + with expect_raises_message( + InvalidRequestError, + "Column 'tid.t1_id' with foreign " + "key to non-table-bound columns is not supported " + "when using a TypedColumns. If possible use the " + "qualified string name the column", + ): + + class a(tid): + a: Column[int] + + Table("a", metadata, a) + + def test_with_cols(self, metadata): + class cols(TypedColumns): + id = Column(Integer) + x = Column(String) + + t = Table("a", metadata, cols) + is_(t, t.with_cols(cols)) + + class cols2(TypedColumns): + name = Column(Integer) + + is_(t, t.with_cols(cols2)) # no runtime check is performed + + sq = t.select().subquery() + is_(sq, sq.with_cols(cols)) + cte = t.select().cte() + is_(cte, cte.with_cols(cols)) diff --git a/test/typing/plain_files/orm/typed_froms_orm_interop.py b/test/typing/plain_files/orm/typed_froms_orm_interop.py new file mode 100644 index 0000000000..7a0fa1ac91 --- /dev/null +++ b/test/typing/plain_files/orm/typed_froms_orm_interop.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any +from typing import assert_type +from typing import TypeAlias + +from sqlalchemy import MetaData +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm._orm_constructors import synonym +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.decl_api import as_typed_table +from sqlalchemy.orm.properties import MappedColumn +from sqlalchemy.sql._annotated_cols import TypedColumns +from sqlalchemy.sql.schema import Column + +T_A: TypeAlias = tuple[Any, ...] +meta = MetaData() + + +class Base(DeclarativeBase): + pass + + +class A(Base): + __tablename__ = "a" + + a: MappedColumn[int] + b: MappedColumn[str] + + x: Mapped[str] = synonym("b") + + +class a_cols(A, TypedColumns): + pass + + +assert_type(A.a, InstrumentedAttribute[int]) +assert_type(A().a, int) +assert_type(A.x, InstrumentedAttribute[str]) + +assert_type(a_cols.a, InstrumentedAttribute[int]) +assert_type(a_cols.b, InstrumentedAttribute[str]) +assert_type(a_cols.x, InstrumentedAttribute[str]) + + +def col_instance(arg: a_cols) -> None: + assert_type(arg.a, Column[int]) + assert_type(arg.b, Column[str]) + assert_type(arg.x, str) + + +def test_as_typed_table() -> None: + # plain class + tbl = as_typed_table(A, a_cols) + assert_type(tbl.c.a, Column[int]) + assert_type(tbl.c.b, Column[str]) + assert_type(tbl.c.metadata, MetaData) # not great but inevitable + + # class with __typed_cols__ + class X(Base): + __tablename__ = "b" + + x: MappedColumn[int] + y: MappedColumn[str] + __typed_cols__: x_cols + + class x_cols(X, TypedColumns): + pass + + tblX = as_typed_table(X) + assert_type(tblX.c.x, Column[int]) + assert_type(tblX.c.y, Column[str]) diff --git a/test/typing/plain_files/sql/typed_froms.py b/test/typing/plain_files/sql/typed_froms.py new file mode 100644 index 0000000000..6ad0a3f915 --- /dev/null +++ b/test/typing/plain_files/sql/typed_froms.py @@ -0,0 +1,214 @@ +from typing import Any +from typing import assert_type + +from sqlalchemy import Column +from sqlalchemy import column +from sqlalchemy import ColumnClause +from sqlalchemy import CTE +from sqlalchemy import Integer +from sqlalchemy import Join +from sqlalchemy import MetaData +from sqlalchemy import Named +from sqlalchemy import Select +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import table +from sqlalchemy import TableClause +from sqlalchemy import TypedColumns +from sqlalchemy.sql.base import ReadOnlyColumnCollection +from sqlalchemy.sql.selectable import NamedFromClause +from sqlalchemy.util.typing import Never +from sqlalchemy.util.typing import Unpack + +meta = MetaData() + + +class with_name(TypedColumns): + name: Column[str] + + +class user_col(with_name): + id = Column(Integer, primary_key=True) + age: Column[int] + middle_name: Column[str | None] + + +user = Table("user", meta, user_col) + +assert_type(with_name.name, Never) +assert_type(user, Table[user_col]) +assert_type(user.c.id, Column[int]) +assert_type(user.c.name, Column[str]) +assert_type(user.c.age, Column[int]) +assert_type(user.c.middle_name, Column[str | None]) + +assert_type(select(user.c.age, user.c.name), Select[int, str]) + + +class with_age(TypedColumns): + age: Named[int] + + +class person_col(with_age, with_name): + id: Named[int] = Column(primary_key=True) + + +person = Table("person", meta, person_col) + + +def select_name(table: Table[with_name]) -> Select[str]: + # it's covariant + assert_type(table.c.name, Column[str]) + return select(table.c.name) + + +def default_generic(table: Table) -> None: + assert_type(table.c.name, Column[Any]) + + +def generic_any(table: Table[Any]) -> None: + # TODO: would be nice to have this also be Column[Any] + assert_type(table.c.name, Any) + + +select_name(person) +select_name(user) + +assert_type(person, Table[person_col]) +assert_type(person.c.id, Column[int]) +assert_type(person.c.name, Column[str]) +assert_type(person.c.age, Column[int]) +assert_type(with_age.age, Never) + +assert_type(person.select(), Select[*tuple[Any, ...]]) +assert_type(select(person), Select[*tuple[Any, ...]]) + + +class address_cols(TypedColumns): + user: Named[str] + address: Named[str] + + +address = Table("address", meta, address_cols, Column("extra", Integer)) +assert_type(address, Table[address_cols]) +assert_type(address.c.extra, Column[Any]) +assert_type(address.c["user"], Column[Any]) +assert_type("user" in address.c, bool) +assert_type(address.c.keys(), list[str]) + +plain = Table("a", meta, Column("x", Integer), Column("y", String)) +assert_type(plain, Table[ReadOnlyColumnCollection[str, Column[Any]]]) +assert_type(plain.c.x, Column[Any]) +assert_type(plain.c.y, Column[Any]) + + +class plain_cols(TypedColumns): + x: Named[int | None] + y: Named[str | None] + + +plain_now_typed = plain.with_cols(plain_cols) +assert_type(plain_now_typed, Table[plain_cols]) +assert_type(plain_now_typed.c.x, Column[int | None]) +assert_type(plain_now_typed.c.y, Column[str | None]) + +aa = address.alias() +assert_type(address.c.user, Column[str]) +join = address.join(plain) + +# a join defines a new namespace of cols that is table-prefixed, so +# this part can't be automated +assert_type(join, Join) + + +# but we can cast +class address_join_cols(TypedColumns): + address_user: Named[str] + address_address: Named[str] + + +join_typed = join.with_cols(address_join_cols) +assert_type(join_typed, Join[address_join_cols]) +assert_type(join_typed.c.address_user, Column[str]) +assert_type(join_typed.c.address_x, Column[Any]) + +my_select = select(address.c.user, address.c.address) +my_cte = my_select.cte().with_cols(address_cols) + +assert_type(my_cte, CTE[address_cols]) +assert_type(my_cte.c.address, Column[str]) +my_sq = my_select.subquery().with_cols(address_cols) + +assert_type(my_sq, NamedFromClause[address_cols]) +assert_type(my_sq.c.address, Column[str]) + +alias = person.alias() +assert_type(alias, NamedFromClause[person_col]) +assert_type(alias.with_cols(address_cols), NamedFromClause[address_cols]) + + +class with_name_clause(TypedColumns): + name: ColumnClause[str] + + +lower_table = table("t", column("name", String)).with_cols(with_name_clause) +assert_type(lower_table, TableClause[with_name_clause]) +assert_type(lower_table.c.name, ColumnClause[str]) +lower_table2 = lower_table.with_cols(with_name) +assert_type(lower_table2.c.name, Column[str]) + + +def test_row_pos() -> None: + # no row pos specified, behaves like a normal table + assert_type(select(address), Select[Unpack[tuple[Any, ...]]]) + + class user_cols(TypedColumns): + id: Named[int] + name: Named[str] + age: Named[int] + + __row_pos__: tuple[int, str, int] + + user = Table("user", meta, user_cols) + + class item_cols(TypedColumns): + name: Named[str] + weight: Named[float] + + __row_pos__: tuple[str, float] + + item = Table("item", meta, item_cols) + + assert_type(select(user), Select[int, str, int]) + # NOTE: mypy seems not to understand multiple unpacks... + # https://github.com/python/mypy/issues/20188 + # assert_type(select(item, user), Select[str, float, int, str, int]) + assert_type(select(item, user), Select[str, float]) + assert_type(select(item, user, item), Select[Unpack[tuple[Any, ...]]]) + # col after + assert_type(select(user, person.c.name), Select[int, str, int, str]) + assert_type( + select(user, person.c.name, person.c.id), + Select[int, str, int, str, int], + ) + assert_type( + select(user, person.c.name, person.c.id, person.c.name), + Select[int, str, int, str, int, str], + ) + # col before + assert_type(select(person.c.name, user), Select[str, int, str, int]) + assert_type( + select(person.c.id, person.c.name, person.c.name, user), + Select[int, str, str, int, str, int], + ) + + # select method + assert_type(user.select(), Select[int, str, int]) + assert_type(user.alias().select(), Select[int, str, int]) + join = user.join(item).with_cols(item_cols) + assert_type(join, Join[item_cols]) + # NOTE: mypy does not understand annotations on self + # https://github.com/python/mypy/issues/14243 + # assert_type(join.select(), Select[str, float]) + assert_type(join.select(), Select[Unpack[tuple[Any, ...]]])