From: Mike Bayer Date: Mon, 24 Jan 2022 22:04:27 +0000 (-0500) Subject: establish mypy / typing approach for v2.0 X-Git-Tag: rel_2_0_0b1~487^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e545298e35ea9f126054b337e4b5ba01988b29f7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git establish mypy / typing approach for v2.0 large patch to get ORM / typing efforts started. this is to support adding new test cases to mypy, support dropping sqlalchemy2-stubs entirely from the test suite, validate major ORM typing reorganization to eliminate the need for the mypy plugin. * New declarative approach which uses annotation introspection, fixes: #7535 * Mapped[] is now at the base of all ORM constructs that find themselves in classes, to support direct typing without plugins * Mypy plugin updated for new typing structures * Mypy test suite broken out into "plugin" tests vs. "plain" tests, and enhanced to better support test structures where we assert that various objects are introspected by the type checker as we expect. as we go forward with typing, we will add new use cases to "plain" where we can assert that types are introspected as we expect. * For typing support, users will be much more exposed to the class names of things. Add these all to "sqlalchemy" import space. * Column(ForeignKey()) no longer needs to be `@declared_attr` if the FK refers to a remote table * composite() attributes mapped to a dataclass no longer need to implement a `__composite_values__()` method * with_variant() accepts multiple dialect names Change-Id: I22797c0be73a8fbbd2d6f5e0c0b7258b17fe145d Fixes: #7535 Fixes: #7551 References: #6810 --- diff --git a/doc/build/changelog/migration_20.rst b/doc/build/changelog/migration_20.rst index afab782770..f18720ccf8 100644 --- a/doc/build/changelog/migration_20.rst +++ b/doc/build/changelog/migration_20.rst @@ -182,7 +182,7 @@ and pylance. Given a program as below:: from sqlalchemy.dialects.mysql import VARCHAR - type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql") + type_ = String(255).with_variant(VARCHAR(255, charset='utf8mb4'), "mysql", "mariadb") if typing.TYPE_CHECKING: reveal_type(type_) @@ -191,6 +191,9 @@ A type checker like pyright will now report the type as:: info: Type of "type_" is "String" +In addition, as illustrated above, multiple dialect names may be passed for +single type, in particular this is helpful for the pair of ``"mysql"`` and +``"mariadb"`` dialects which are considered separately as of SQLAlchemy 1.4. :ticket:`6980` diff --git a/doc/build/changelog/unreleased_20/6980.rst b/doc/build/changelog/unreleased_20/6980.rst index d83599c48c..90cf74044d 100644 --- a/doc/build/changelog/unreleased_20/6980.rst +++ b/doc/build/changelog/unreleased_20/6980.rst @@ -10,6 +10,10 @@ behaviors, maintaining the original type allows for clearer type checking and debugging. + :meth:`_sqltypes.TypeEngine.with_variant` also accepts multiple dialect + names per call as well, in particular this is helpful for related + backend names such as ``"mysql", "mariadb"``. + .. seealso:: :ref:`change_6980` diff --git a/doc/build/changelog/unreleased_20/composite_dataclass.rst b/doc/build/changelog/unreleased_20/composite_dataclass.rst new file mode 100644 index 0000000000..a7312b0bd4 --- /dev/null +++ b/doc/build/changelog/unreleased_20/composite_dataclass.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: feature, orm + + The :func:`_orm.composite` mapping construct now supports automatic + resolution of values when used with a Python ``dataclass``; the + ``__composite_values__()`` method no longer needs to be implemented as this + method is derived from inspection of the dataclass. + + See the new documentation at :ref:`mapper_composite` for examples. \ No newline at end of file diff --git a/doc/build/changelog/unreleased_20/decl_fks.rst b/doc/build/changelog/unreleased_20/decl_fks.rst new file mode 100644 index 0000000000..94de46eac6 --- /dev/null +++ b/doc/build/changelog/unreleased_20/decl_fks.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: feature, orm + + Declarative mixins which use :class:`_schema.Column` objects that contain + :class:`_schema.ForeignKey` references no longer need to use + :func:`_orm.declared_attr` to achieve this mapping; the + :class:`_schema.ForeignKey` object is copied along with the + :class:`_schema.Column` itself when the column is applied to the declared + mapping. \ No newline at end of file diff --git a/doc/build/changelog/unreleased_20/prop_name.rst b/doc/build/changelog/unreleased_20/prop_name.rst new file mode 100644 index 0000000000..d085d0ddce --- /dev/null +++ b/doc/build/changelog/unreleased_20/prop_name.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: change, orm + + To better accommodate explicit typing, the names of some ORM constructs + that are typically constructed internally, but nonetheless are sometimes + visible in messaging as well as typing, have been changed to more succinct + names which also match the name of their constructing function (with + different casing), in all cases maintaining aliases to the old names for + the forseeable future: + + * :class:`_orm.RelationshipProperty` becomes an alias for the primary name + :class:`_orm.Relationship`, which is constructed as always from the + :func:`_orm.relationship` function + * :class:`_orm.SynonymProperty` becomes an alias for the primary name + :class:`_orm.Synonym`, constructed as always from the + :func:`_orm.synonym` function + * :class:`_orm.CompositeProperty` becomes an alias for the primary name + :class:`_orm.Composite`, constructed as always from the + :func:`_orm.composite` function \ No newline at end of file diff --git a/doc/build/core/selectable.rst b/doc/build/core/selectable.rst index bad7dc8090..812f6f99a9 100644 --- a/doc/build/core/selectable.rst +++ b/doc/build/core/selectable.rst @@ -156,21 +156,13 @@ Label Style Constants Constants used with the :meth:`_sql.GenerativeSelect.set_label_style` method. -.. autodata:: LABEL_STYLE_DISAMBIGUATE_ONLY +.. autoclass:: SelectLabelStyle + :members: -.. autodata:: LABEL_STYLE_NONE -.. autodata:: LABEL_STYLE_TABLENAME_PLUS_COL +.. seealso:: -.. data:: LABEL_STYLE_DEFAULT + :meth:`_sql.Select.set_label_style` - The default label style, refers to :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. - - .. versionadded:: 1.4 - - .. seealso:: - - :meth:`_sql.Select.set_label_style` - - :meth:`_sql.Select.get_label_style` + :meth:`_sql.Select.get_label_style` diff --git a/doc/build/orm/composites.rst b/doc/build/orm/composites.rst index 0628f56aef..463bb70bc3 100644 --- a/doc/build/orm/composites.rst +++ b/doc/build/orm/composites.rst @@ -5,6 +5,11 @@ Composite Column Types ====================== +.. note:: + + This documentation is not yet updated to illustrate the new + typing-annotation syntax or direct support for dataclasses. + Sets of columns can be associated with a single user-defined datatype. The ORM provides a single attribute which represents the group of columns using the class you provide. diff --git a/doc/build/orm/declarative_mixins.rst b/doc/build/orm/declarative_mixins.rst index 9bb4c782e4..e78b966986 100644 --- a/doc/build/orm/declarative_mixins.rst +++ b/doc/build/orm/declarative_mixins.rst @@ -125,47 +125,61 @@ for each separate destination class. To accomplish this, the declarative extension creates a **copy** of each :class:`_schema.Column` object encountered on a class that is detected as a mixin. -This copy mechanism is limited to simple columns that have no foreign -keys, as a :class:`_schema.ForeignKey` itself contains references to columns -which can't be properly recreated at this level. For columns that -have foreign keys, as well as for the variety of mapper-level constructs -that require destination-explicit context, the -:class:`_orm.declared_attr` decorator is provided so that -patterns common to many classes can be defined as callables:: +This copy mechanism is limited to :class:`_schema.Column` and +:class:`_orm.MappedColumn` constructs. For :class:`_schema.Column` and +:class:`_orm.MappedColumn` constructs that contain references to +:class:`_schema.ForeignKey` constructs, the copy mechanism is limited to +foreign key references to remote tables only. + +.. versionchanged:: 2.0 The declarative API can now accommodate + :class:`_schema.Column` objects which refer to :class:`_schema.ForeignKey` + constraints to remote tables without the need to use the + :class:`_orm.declared_attr` function decorator. + +For the variety of mapper-level constructs that require destination-explicit +context, including self-referential foreign keys and constructs like +:func:`_orm.deferred`, :func:`_orm.relationship`, etc, the +:class:`_orm.declared_attr` decorator is provided so that patterns common to +many classes can be defined as callables:: from sqlalchemy.orm import declared_attr @declarative_mixin - class ReferenceAddressMixin: + class HasRelatedDataMixin: @declared_attr - def address_id(cls): - return Column(Integer, ForeignKey('address.id')) + def related_data(cls): + return deferred(Column(Text()) - class User(ReferenceAddressMixin, Base): + class User(HasRelatedDataMixin, Base): __tablename__ = 'user' id = Column(Integer, primary_key=True) -Where above, the ``address_id`` class-level callable is executed at the +Where above, the ``related_data`` class-level callable is executed at the point at which the ``User`` class is constructed, and the declarative -extension can use the resulting :class:`_schema.Column` object as returned by +extension can use the resulting :func`_orm.deferred` object as returned by the method without the need to copy it. -Columns generated by :class:`_orm.declared_attr` can also be -referenced by ``__mapper_args__`` to a limited degree, currently -by ``polymorphic_on`` and ``version_id_col``; the declarative extension -will resolve them at class construction time:: +For a self-referential foreign key on a mixin, the referenced +:class:`_schema.Column` object may be referenced in terms of the class directly +within the :class:`_orm.declared_attr`:: - @declarative_mixin - class MyMixin: - @declared_attr - def type_(cls): - return Column(String(50)) + class SelfReferentialMixin: + id = Column(Integer, primary_key=True) - __mapper_args__= {'polymorphic_on':type_} + @declared_attr + def parent_id(cls): + return Column(Integer, ForeignKey(cls.id)) + + class A(SelfReferentialMixin, Base): + __tablename__ = 'a' - class MyModel(MyMixin, Base): - __tablename__='test' - id = Column(Integer, primary_key=True) + + class B(SelfReferentialMixin, Base): + __tablename__ = 'b' + +Above, both classes ``A`` and ``B`` will contain columns ``id`` and +``parent_id``, where ``parent_id`` refers to the ``id`` column local to the +corresponding table ('a' or 'b'). .. _orm_declarative_mixins_relationships: @@ -182,9 +196,7 @@ reference a common target class via many-to-one:: @declarative_mixin class RefTargetMixin: - @declared_attr - def target_id(cls): - return Column('target_id', ForeignKey('target.id')) + target_id = Column('target_id', ForeignKey('target.id')) @declared_attr def target(cls): diff --git a/doc/build/orm/internals.rst b/doc/build/orm/internals.rst index 8520fd07c1..05cf83b394 100644 --- a/doc/build/orm/internals.rst +++ b/doc/build/orm/internals.rst @@ -32,9 +32,10 @@ sections, are listed here. :ref:`maptojoin` - usage example -.. autoclass:: CompositeProperty +.. autoclass:: Composite :members: +.. autodata:: CompositeProperty .. autoclass:: AttributeEvent :members: @@ -62,6 +63,8 @@ sections, are listed here. .. autoclass:: Mapped +.. autoclass:: MappedColumn + .. autoclass:: MapperProperty :members: @@ -98,14 +101,18 @@ sections, are listed here. :members: :inherited-members: -.. autoclass:: RelationshipProperty +.. autoclass:: Relationship :members: :inherited-members: -.. autoclass:: SynonymProperty +.. autodata:: RelationshipProperty + +.. autoclass:: Synonym :members: :inherited-members: +.. autodata:: SynonymProperty + .. autoclass:: QueryContext :members: diff --git a/doc/build/orm/loading_relationships.rst b/doc/build/orm/loading_relationships.rst index 2b93bc84af..773409f027 100644 --- a/doc/build/orm/loading_relationships.rst +++ b/doc/build/orm/loading_relationships.rst @@ -1261,8 +1261,6 @@ Relationship Loader API .. autofunction:: defaultload -.. autofunction:: eagerload - .. autofunction:: immediateload .. autofunction:: joinedload diff --git a/doc/build/orm/relationship_api.rst b/doc/build/orm/relationship_api.rst index 2766c4020a..ac584627f9 100644 --- a/doc/build/orm/relationship_api.rst +++ b/doc/build/orm/relationship_api.rst @@ -7,8 +7,6 @@ Relationships API .. autofunction:: backref -.. autofunction:: relation - .. autofunction:: dynamic_loader .. autofunction:: foreign diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index c8ec1d8250..eadb427d0d 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -6,10 +6,56 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from . import util as _util +from .engine import AdaptedConnection as AdaptedConnection +from .engine import BaseCursorResult as BaseCursorResult +from .engine import BaseRow as BaseRow +from .engine import BindTyping as BindTyping +from .engine import BufferedColumnResultProxy as BufferedColumnResultProxy +from .engine import BufferedColumnRow as BufferedColumnRow +from .engine import BufferedRowResultProxy as BufferedRowResultProxy +from .engine import ChunkedIteratorResult as ChunkedIteratorResult +from .engine import Compiled as Compiled +from .engine import Connection as Connection from .engine import create_engine as create_engine from .engine import create_mock_engine as create_mock_engine +from .engine import CreateEnginePlugin as CreateEnginePlugin +from .engine import CursorResult as CursorResult +from .engine import Dialect as Dialect +from .engine import Engine as Engine from .engine import engine_from_config as engine_from_config +from .engine import ExceptionContext as ExceptionContext +from .engine import ExecutionContext as ExecutionContext +from .engine import FrozenResult as FrozenResult +from .engine import FullyBufferedResultProxy as FullyBufferedResultProxy +from .engine import Inspector as Inspector +from .engine import IteratorResult as IteratorResult +from .engine import make_url as make_url +from .engine import MappingResult as MappingResult +from .engine import MergedResult as MergedResult +from .engine import NestedTransaction as NestedTransaction +from .engine import Result as Result +from .engine import result_tuple as result_tuple +from .engine import ResultProxy as ResultProxy +from .engine import RootTransaction as RootTransaction +from .engine import Row as Row +from .engine import RowMapping as RowMapping +from .engine import ScalarResult as ScalarResult +from .engine import Transaction as Transaction +from .engine import TwoPhaseTransaction as TwoPhaseTransaction +from .engine import TypeCompiler as TypeCompiler +from .engine import URL as URL from .inspection import inspect as inspect +from .pool import AssertionPool as AssertionPool +from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .pool import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .pool import NullPool as NullPool +from .pool import Pool as Pool +from .pool import PoolProxiedConnection as PoolProxiedConnection +from .pool import QueuePool as QueuePool +from .pool import SingletonThreadPool as SingleonThreadPool +from .pool import StaticPool as StaticPool from .schema import BLANK_SCHEMA as BLANK_SCHEMA from .schema import CheckConstraint as CheckConstraint from .schema import Column as Column @@ -28,67 +74,139 @@ from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .schema import Sequence as Sequence from .schema import Table as Table from .schema import UniqueConstraint as UniqueConstraint -from .sql import alias as alias -from .sql import all_ as all_ -from .sql import and_ as and_ -from .sql import any_ as any_ -from .sql import asc as asc -from .sql import between as between -from .sql import bindparam as bindparam -from .sql import case as case -from .sql import cast as cast -from .sql import collate as collate -from .sql import column as column -from .sql import delete as delete -from .sql import desc as desc -from .sql import distinct as distinct -from .sql import except_ as except_ -from .sql import except_all as except_all -from .sql import exists as exists -from .sql import extract as extract -from .sql import false as false -from .sql import func as func -from .sql import funcfilter as funcfilter -from .sql import insert as insert -from .sql import intersect as intersect -from .sql import intersect_all as intersect_all -from .sql import join as join -from .sql import label as label -from .sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT -from .sql import ( +from .sql import SelectLabelStyle as SelectLabelStyle +from .sql.expression import Alias as Alias +from .sql.expression import alias as alias +from .sql.expression import AliasedReturnsRows as AliasedReturnsRows +from .sql.expression import all_ as all_ +from .sql.expression import and_ as and_ +from .sql.expression import any_ as any_ +from .sql.expression import asc as asc +from .sql.expression import between as between +from .sql.expression import BinaryExpression as BinaryExpression +from .sql.expression import bindparam as bindparam +from .sql.expression import BindParameter as BindParameter +from .sql.expression import BooleanClauseList as BooleanClauseList +from .sql.expression import CacheKey as CacheKey +from .sql.expression import Case as Case +from .sql.expression import case as case +from .sql.expression import Cast as Cast +from .sql.expression import cast as cast +from .sql.expression import ClauseElement as ClauseElement +from .sql.expression import ClauseList as ClauseList +from .sql.expression import collate as collate +from .sql.expression import CollectionAggregate as CollectionAggregate +from .sql.expression import column as column +from .sql.expression import ColumnClause as ColumnClause +from .sql.expression import ColumnCollection as ColumnCollection +from .sql.expression import ColumnElement as ColumnElement +from .sql.expression import ColumnOperators as ColumnOperators +from .sql.expression import CompoundSelect as CompoundSelect +from .sql.expression import CTE as CTE +from .sql.expression import cte as cte +from .sql.expression import custom_op as custom_op +from .sql.expression import Delete as Delete +from .sql.expression import delete as delete +from .sql.expression import desc as desc +from .sql.expression import distinct as distinct +from .sql.expression import except_ as except_ +from .sql.expression import except_all as except_all +from .sql.expression import Executable as Executable +from .sql.expression import Exists as Exists +from .sql.expression import exists as exists +from .sql.expression import Extract as Extract +from .sql.expression import extract as extract +from .sql.expression import false as false +from .sql.expression import False_ as False_ +from .sql.expression import FromClause as FromClause +from .sql.expression import FromGrouping as FromGrouping +from .sql.expression import func as func +from .sql.expression import funcfilter as funcfilter +from .sql.expression import Function as Function +from .sql.expression import FunctionElement as FunctionElement +from .sql.expression import FunctionFilter as FunctionFilter +from .sql.expression import GenerativeSelect as GenerativeSelect +from .sql.expression import Grouping as Grouping +from .sql.expression import HasCTE as HasCTE +from .sql.expression import HasPrefixes as HasPrefixes +from .sql.expression import HasSuffixes as HasSuffixes +from .sql.expression import Insert as Insert +from .sql.expression import insert as insert +from .sql.expression import intersect as intersect +from .sql.expression import intersect_all as intersect_all +from .sql.expression import Join as Join +from .sql.expression import join as join +from .sql.expression import Label as Label +from .sql.expression import label as label +from .sql.expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .sql.expression import ( LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, ) -from .sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE -from .sql import ( +from .sql.expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .sql.expression import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) -from .sql import lambda_stmt as lambda_stmt -from .sql import lateral as lateral -from .sql import literal as literal -from .sql import literal_column as literal_column -from .sql import modifier as modifier -from .sql import not_ as not_ -from .sql import null as null -from .sql import nulls_first as nulls_first -from .sql import nulls_last as nulls_last -from .sql import nullsfirst as nullsfirst -from .sql import nullslast as nullslast -from .sql import or_ as or_ -from .sql import outerjoin as outerjoin -from .sql import outparam as outparam -from .sql import over as over -from .sql import select as select -from .sql import table as table -from .sql import tablesample as tablesample -from .sql import text as text -from .sql import true as true -from .sql import tuple_ as tuple_ -from .sql import type_coerce as type_coerce -from .sql import union as union -from .sql import union_all as union_all -from .sql import update as update -from .sql import values as values -from .sql import within_group as within_group +from .sql.expression import lambda_stmt as lambda_stmt +from .sql.expression import LambdaElement as LambdaElement +from .sql.expression import Lateral as Lateral +from .sql.expression import lateral as lateral +from .sql.expression import literal as literal +from .sql.expression import literal_column as literal_column +from .sql.expression import modifier as modifier +from .sql.expression import not_ as not_ +from .sql.expression import Null as Null +from .sql.expression import null as null +from .sql.expression import nulls_first as nulls_first +from .sql.expression import nulls_last as nulls_last +from .sql.expression import Operators as Operators +from .sql.expression import or_ as or_ +from .sql.expression import outerjoin as outerjoin +from .sql.expression import outparam as outparam +from .sql.expression import Over as Over +from .sql.expression import over as over +from .sql.expression import quoted_name as quoted_name +from .sql.expression import ReleaseSavepointClause as ReleaseSavepointClause +from .sql.expression import ReturnsRows as ReturnsRows +from .sql.expression import ( + RollbackToSavepointClause as RollbackToSavepointClause, +) +from .sql.expression import SavepointClause as SavepointClause +from .sql.expression import ScalarSelect as ScalarSelect +from .sql.expression import Select as Select +from .sql.expression import select as select +from .sql.expression import Selectable as Selectable +from .sql.expression import SelectBase as SelectBase +from .sql.expression import StatementLambdaElement as StatementLambdaElement +from .sql.expression import Subquery as Subquery +from .sql.expression import table as table +from .sql.expression import TableClause as TableClause +from .sql.expression import TableSample as TableSample +from .sql.expression import tablesample as tablesample +from .sql.expression import TableValuedAlias as TableValuedAlias +from .sql.expression import text as text +from .sql.expression import TextAsFrom as TextAsFrom +from .sql.expression import TextClause as TextClause +from .sql.expression import TextualSelect as TextualSelect +from .sql.expression import true as true +from .sql.expression import True_ as True_ +from .sql.expression import Tuple as Tuple +from .sql.expression import tuple_ as tuple_ +from .sql.expression import type_coerce as type_coerce +from .sql.expression import TypeClause as TypeClause +from .sql.expression import TypeCoerce as TypeCoerce +from .sql.expression import typing as typing +from .sql.expression import UnaryExpression as UnaryExpression +from .sql.expression import union as union +from .sql.expression import union_all as union_all +from .sql.expression import Update as Update +from .sql.expression import update as update +from .sql.expression import UpdateBase as UpdateBase +from .sql.expression import Values as Values +from .sql.expression import values as values +from .sql.expression import ValuesBase as ValuesBase +from .sql.expression import Visitable as Visitable +from .sql.expression import within_group as within_group +from .sql.expression import WithinGroup as WithinGroup from .types import ARRAY as ARRAY from .types import BIGINT as BIGINT from .types import BigInteger as BigInteger @@ -133,7 +251,6 @@ from .types import UnicodeText as UnicodeText from .types import VARBINARY as VARBINARY from .types import VARCHAR as VARCHAR - __version__ = "2.0.0b1" diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index e934f9f891..c6bc4b6aa6 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -15,45 +15,45 @@ constructor ``create_engine()``. """ -from . import events -from . import util -from .base import Connection -from .base import Engine -from .base import NestedTransaction -from .base import RootTransaction -from .base import Transaction -from .base import TwoPhaseTransaction -from .create import create_engine -from .create import engine_from_config -from .cursor import BaseCursorResult -from .cursor import BufferedColumnResultProxy -from .cursor import BufferedColumnRow -from .cursor import BufferedRowResultProxy -from .cursor import CursorResult -from .cursor import FullyBufferedResultProxy -from .cursor import ResultProxy -from .interfaces import AdaptedConnection -from .interfaces import BindTyping -from .interfaces import Compiled -from .interfaces import CreateEnginePlugin -from .interfaces import Dialect -from .interfaces import ExceptionContext -from .interfaces import ExecutionContext -from .interfaces import TypeCompiler -from .mock import create_mock_engine -from .reflection import Inspector -from .result import ChunkedIteratorResult -from .result import FrozenResult -from .result import IteratorResult -from .result import MappingResult -from .result import MergedResult -from .result import Result -from .result import result_tuple -from .result import ScalarResult -from .row import BaseRow -from .row import Row -from .row import RowMapping -from .url import make_url -from .url import URL -from .util import connection_memoize -from ..sql import ddl +from . import events as events +from . import util as util +from .base import Connection as Connection +from .base import Engine as Engine +from .base import NestedTransaction as NestedTransaction +from .base import RootTransaction as RootTransaction +from .base import Transaction as Transaction +from .base import TwoPhaseTransaction as TwoPhaseTransaction +from .create import create_engine as create_engine +from .create import engine_from_config as engine_from_config +from .cursor import BaseCursorResult as BaseCursorResult +from .cursor import BufferedColumnResultProxy as BufferedColumnResultProxy +from .cursor import BufferedColumnRow as BufferedColumnRow +from .cursor import BufferedRowResultProxy as BufferedRowResultProxy +from .cursor import CursorResult as CursorResult +from .cursor import FullyBufferedResultProxy as FullyBufferedResultProxy +from .cursor import ResultProxy as ResultProxy +from .interfaces import AdaptedConnection as AdaptedConnection +from .interfaces import BindTyping as BindTyping +from .interfaces import Compiled as Compiled +from .interfaces import CreateEnginePlugin as CreateEnginePlugin +from .interfaces import Dialect as Dialect +from .interfaces import ExceptionContext as ExceptionContext +from .interfaces import ExecutionContext as ExecutionContext +from .interfaces import TypeCompiler as TypeCompiler +from .mock import create_mock_engine as create_mock_engine +from .reflection import Inspector as Inspector +from .result import ChunkedIteratorResult as ChunkedIteratorResult +from .result import FrozenResult as FrozenResult +from .result import IteratorResult as IteratorResult +from .result import MappingResult as MappingResult +from .result import MergedResult as MergedResult +from .result import Result as Result +from .result import result_tuple as result_tuple +from .result import ScalarResult as ScalarResult +from .row import BaseRow as BaseRow +from .row import Row as Row +from .row import RowMapping as RowMapping +from .url import make_url as make_url +from .url import URL as URL +from .util import connection_memoize as connection_memoize +from ..sql import ddl as ddl diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 6fb8279894..2f8ce17df9 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -6,6 +6,7 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from typing import Any +from typing import Union from . import base from . import url as _url @@ -41,7 +42,7 @@ from ..sql import compiler "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: "_url.URL", **kwargs: Any) -> "base.Engine": +def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL ` as the diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index df7a53ab7d..882392e9c2 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -24,11 +24,13 @@ methods such as get_table_names, get_columns, etc. use the key 'name'. So for most return values, each record will have a 'name' attribute.. """ - import contextlib +from typing import List +from typing import Optional from .base import Connection from .base import Engine +from .interfaces import ReflectedColumn from .. import exc from .. import inspection from .. import sql @@ -433,7 +435,9 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, view_name, schema, info_cache=self.info_cache ) - def get_columns(self, table_name, schema=None, **kw): + def get_columns( + self, table_name: str, schema: Optional[str] = None, **kw + ) -> List[ReflectedColumn]: """Return information about columns in `table_name`. Given a string `table_name` and an optional string `schema`, return diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index e6a826c649..d5119907ed 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -361,7 +361,7 @@ class AssociationProxyInstance: prop = orm.class_mapper(owning_class).get_property(target_collection) # this was never asserted before but this should be made clear. - if not isinstance(prop, orm.RelationshipProperty): + if not isinstance(prop, orm.Relationship): raise NotImplementedError( "association proxy to a non-relationship " "intermediary is not supported" @@ -717,8 +717,8 @@ class AssociationProxyInstance: """Produce a proxied 'any' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any` - and/or :meth:`.RelationshipProperty.Comparator.has` + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` operators of the underlying proxied attributes. """ @@ -737,8 +737,8 @@ class AssociationProxyInstance: """Produce a proxied 'has' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any` - and/or :meth:`.RelationshipProperty.Comparator.has` + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` operators of the underlying proxied attributes. """ @@ -859,9 +859,9 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance): """Produce a proxied 'contains' expression using EXISTS. This expression will be a composed product - using the :meth:`.RelationshipProperty.Comparator.any`, - :meth:`.RelationshipProperty.Comparator.has`, - and/or :meth:`.RelationshipProperty.Comparator.contains` + using the :meth:`.Relationship.Comparator.any`, + :meth:`.Relationship.Comparator.has`, + and/or :meth:`.Relationship.Comparator.contains` operators of the underlying proxied attributes. """ diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index 5aff4dfe27..470ff6ad88 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -378,7 +378,7 @@ class DeferredReflection: metadata = mapper.class_.metadata for rel in mapper._props.values(): if ( - isinstance(rel, relationships.RelationshipProperty) + isinstance(rel, relationships.Relationship) and rel.secondary is not None ): if isinstance(rel.secondary, Table): diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 99be194cdc..4e244b5b9e 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -36,6 +36,7 @@ from mypy.types import UnionType from . import infer from . import util +from .names import expr_to_mapped_constructor from .names import NAMED_TYPE_SQLA_MAPPED @@ -117,6 +118,7 @@ def re_apply_declarative_assignments( ): left_node = stmt.lvalues[0].node + python_type_for_type = mapped_attr_lookup[ stmt.lvalues[0].name ].type @@ -142,7 +144,7 @@ def re_apply_declarative_assignments( ) ): - python_type_for_type = ( + new_python_type_for_type = ( infer.infer_type_from_right_hand_nameexpr( api, stmt, @@ -152,19 +154,27 @@ def re_apply_declarative_assignments( ) ) - if python_type_for_type is None or isinstance( - python_type_for_type, UnboundType + if new_python_type_for_type is not None and not isinstance( + new_python_type_for_type, UnboundType ): - continue + python_type_for_type = new_python_type_for_type - # update the SQLAlchemyAttribute with the better information - mapped_attr_lookup[ - stmt.lvalues[0].name - ].type = python_type_for_type + # update the SQLAlchemyAttribute with the better + # information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type - update_cls_metadata = True + update_cls_metadata = True - if python_type_for_type is not None: + # for some reason if you have a Mapped type explicitly annotated, + # and here you set it again, mypy forgets how to do descriptors. + # no idea. 100% feeling around in the dark to see what sticks + if ( + not isinstance(left_node.type, Instance) + or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED + ): + assert python_type_for_type is not None left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] ) @@ -202,6 +212,7 @@ def apply_type_to_mapped_statement( assert isinstance(left_node, Var) if left_hand_explicit_type is not None: + lvalue.is_inferred_def = False left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] ) @@ -224,7 +235,7 @@ def apply_type_to_mapped_statement( # _sa_Mapped._empty_constructor() # the original right-hand side is maintained so it gets type checked # internally - stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue) + stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue) def add_additional_orm_attributes( diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index c33c30e257..bd6c6f41e8 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -337,7 +337,7 @@ def _scan_declarative_decorator_stmt( # : Mapped[] = # _sa_Mapped._empty_constructor(lambda: ) # the function body is maintained so it gets type checked internally - rvalue = util.expr_to_mapped_constructor( + rvalue = names.expr_to_mapped_constructor( LambdaExpr(stmt.func.arguments, stmt.func.body) ) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index 3cd946e04d..6a5e99e480 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -42,11 +42,13 @@ def infer_type_from_right_hand_nameexpr( left_hand_explicit_type: Optional[ProperType], infer_from_right_side: RefExpr, ) -> Optional[ProperType]: - type_id = names.type_id_for_callee(infer_from_right_side) - if type_id is None: return None + elif type_id is names.MAPPED: + python_type_for_type = _infer_type_from_mapped( + api, stmt, node, left_hand_explicit_type, infer_from_right_side + ) elif type_id is names.COLUMN: python_type_for_type = _infer_type_from_decl_column( api, stmt, node, left_hand_explicit_type @@ -245,7 +247,7 @@ def _infer_type_from_decl_composite_property( node: Var, left_hand_explicit_type: Optional[ProperType], ) -> Optional[ProperType]: - """Infer the type of mapping from a CompositeProperty.""" + """Infer the type of mapping from a Composite.""" assert isinstance(stmt.rvalue, CallExpr) target_cls_arg = stmt.rvalue.args[0] @@ -271,6 +273,38 @@ def _infer_type_from_decl_composite_property( return python_type_for_type +def _infer_type_from_mapped( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + """Infer the type of mapping from a right side expression + that returns Mapped. + + + """ + assert isinstance(stmt.rvalue, CallExpr) + + # (Pdb) print(stmt.rvalue.callee) + # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501 + # (Pdb) stmt.rvalue.callee.node + # + # (Pdb) stmt.rvalue.callee.node.type + # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501 + # sqlalchemy.orm.base.Mapped[_T`-1] + # the_mapped_type = stmt.rvalue.callee.node.type.ret_type + + # TODO: look at generic ref and either use that, + # or reconcile w/ what's present, etc. + the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + def _infer_type_from_decl_column_property( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index b6f911979c..ad4449e5bb 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -12,11 +12,14 @@ from typing import Set from typing import Tuple from typing import Union +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Expression from mypy.nodes import FuncDef from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef from mypy.nodes import SymbolNode from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo @@ -51,7 +54,7 @@ QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" NAMED_TYPE_BUILTINS_STR = "builtins.str" NAMED_TYPE_BUILTINS_LIST = "builtins.list" -NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" _lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( @@ -61,11 +64,11 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.sql.Column", }, ), - "RelationshipProperty": ( + "Relationship": ( RELATIONSHIP, { - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.RelationshipProperty", + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.Relationship", }, ), "registry": ( @@ -82,18 +85,18 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.orm.ColumnProperty", }, ), - "SynonymProperty": ( + "Synonym": ( SYNONYM_PROPERTY, { - "sqlalchemy.orm.descriptor_props.SynonymProperty", - "sqlalchemy.orm.SynonymProperty", + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", }, ), - "CompositeProperty": ( + "Composite": ( COMPOSITE_PROPERTY, { - "sqlalchemy.orm.descriptor_props.CompositeProperty", - "sqlalchemy.orm.CompositeProperty", + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", }, ), "MapperProperty": ( @@ -159,7 +162,10 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { ), "query_expression": ( QUERY_EXPRESSION, - {"sqlalchemy.orm.query_expression"}, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, ), } @@ -209,7 +215,19 @@ def type_id_for_unbound_type( def type_id_for_callee(callee: Expression) -> Optional[int]: if isinstance(callee, (MemberExpr, NameExpr)): - if isinstance(callee.node, FuncDef): + if isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): if callee.node.type and isinstance(callee.node.type, CallableType): ret_type = get_proper_type(callee.node.type.ret_type) @@ -251,3 +269,15 @@ def type_id_for_fullname(fullname: str) -> Optional[int]: return type_id else: return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 0a21feb51f..c9520fef33 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -40,6 +40,19 @@ from . import decl_class from . import names from . import util +try: + import sqlalchemy_stubs # noqa +except ImportError: + pass +else: + import sqlalchemy + + raise ImportError( + f"The SQLAlchemy mypy plugin in SQLAlchemy " + f"{sqlalchemy.__version__} does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed" + ) + class SQLAlchemyPlugin(Plugin): def get_dynamic_class_hook( diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index fa42074c39..741772eacd 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -10,24 +10,27 @@ from typing import Type as TypingType from typing import TypeVar from typing import Union -from mypy.nodes import ARG_POS from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import Expression +from mypy.nodes import FuncDef from mypy.nodes import IfStmt from mypy.nodes import JsonDict from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import Statement from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias from mypy.nodes import TypeInfo from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type from mypy.types import Instance from mypy.types import NoneType from mypy.types import Type @@ -231,6 +234,25 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: yield stmt +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + def unbound_to_instance( api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: @@ -290,15 +312,3 @@ def info_for_cls( return sym.node return cls.info - - -def expr_to_mapped_constructor(expr: Expression) -> CallExpr: - column_descriptor = NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - member_expr = MemberExpr(column_descriptor, "_empty_constructor") - return CallExpr( - member_expr, - [expr], - [ARG_POS], - ["arg1"], - ) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 5a327d1a52..5384851b10 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -119,14 +119,28 @@ start numbering at 1 or some other integer, provide ``count_from=1``. """ +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + from ..orm.collections import collection from ..orm.collections import collection_adapter +_T = TypeVar("_T") +OrderingFunc = Callable[[int, Sequence[_T]], int] + __all__ = ["ordering_list"] -def ordering_list(attr, count_from=None, **kw): +def ordering_list( + attr: str, + count_from: Optional[int] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, +) -> Callable[[], "OrderingList"]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. Returns an object suitable for use as an argument to a Mapper @@ -157,7 +171,11 @@ def ordering_list(attr, count_from=None, **kw): """ - kw = _unsugar_count_from(count_from=count_from, **kw) + kw = _unsugar_count_from( + count_from=count_from, + ordering_func=ordering_func, + reorder_on_append=reorder_on_append, + ) return lambda: OrderingList(attr, **kw) @@ -207,7 +225,7 @@ def _unsugar_count_from(**kw): return kw -class OrderingList(list): +class OrderingList(List[_T]): """A custom list that manages position information for its children. The :class:`.OrderingList` object is normally set up using the @@ -216,8 +234,15 @@ class OrderingList(list): """ + ordering_attr: str + ordering_func: OrderingFunc + reorder_on_append: bool + def __init__( - self, ordering_attr=None, ordering_func=None, reorder_on_append=False + self, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, ): """A custom list that manages position information for its children. @@ -282,7 +307,7 @@ class OrderingList(list): def _set_order_value(self, entity, value): setattr(entity, self.ordering_attr, value) - def reorder(self): + def reorder(self) -> None: """Synchronize ordering for the entire collection. Sweeps through the list and ensures that each object has accurate diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 885163ecbd..c6a8b6ea7f 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -74,6 +74,8 @@ def class_logger(cls: Type[_IT]) -> Type[_IT]: class Identified: + __slots__ = () + logging_name: Optional[str] = None logger: Union[logging.Logger, "InstanceLogger"] @@ -116,6 +118,8 @@ class InstanceLogger: _echo: _EchoFlagType + __slots__ = ("echo", "logger") + def __init__(self, echo: _EchoFlagType, name: str): self.echo = echo self.logger = logging.getLogger(name) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 55f2f31000..bbed933104 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -17,19 +17,27 @@ from . import exc as exc from . import mapper as mapperlib from . import strategy_options as strategy_options from ._orm_constructors import _mapper_fn as mapper +from ._orm_constructors import aliased as aliased from ._orm_constructors import backref as backref from ._orm_constructors import clear_mappers as clear_mappers from ._orm_constructors import column_property as column_property from ._orm_constructors import composite as composite +from ._orm_constructors import CompositeProperty as CompositeProperty from ._orm_constructors import contains_alias as contains_alias from ._orm_constructors import create_session as create_session from ._orm_constructors import deferred as deferred from ._orm_constructors import dynamic_loader as dynamic_loader +from ._orm_constructors import join as join from ._orm_constructors import mapped_column as mapped_column +from ._orm_constructors import MappedColumn as MappedColumn +from ._orm_constructors import outerjoin as outerjoin from ._orm_constructors import query_expression as query_expression from ._orm_constructors import relationship as relationship +from ._orm_constructors import RelationshipProperty as RelationshipProperty from ._orm_constructors import synonym as synonym +from ._orm_constructors import SynonymProperty as SynonymProperty from ._orm_constructors import with_loader_criteria as with_loader_criteria +from ._orm_constructors import with_polymorphic as with_polymorphic from .attributes import AttributeEvent as AttributeEvent from .attributes import InstrumentedAttribute as InstrumentedAttribute from .attributes import QueryableAttribute as QueryableAttribute @@ -46,8 +54,8 @@ from .decl_api import declared_attr as declared_attr from .decl_api import has_inherited_table as has_inherited_table from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for -from .descriptor_props import CompositeProperty as CompositeProperty -from .descriptor_props import SynonymProperty as SynonymProperty +from .descriptor_props import Composite as Composite +from .descriptor_props import Synonym as Synonym from .dynamic import AppenderQuery as AppenderQuery from .events import AttributeEvents as AttributeEvents from .events import InstanceEvents as InstanceEvents @@ -81,7 +89,7 @@ from .query import AliasOption as AliasOption from .query import FromStatement as FromStatement from .query import Query as Query from .relationships import foreign as foreign -from .relationships import RelationshipProperty as RelationshipProperty +from .relationships import Relationship as Relationship from .relationships import remote as remote from .scoping import scoped_session as scoped_session from .session import close_all_sessions as close_all_sessions @@ -111,17 +119,13 @@ from .strategy_options import undefer as undefer from .strategy_options import undefer_group as undefer_group from .strategy_options import with_expression as with_expression from .unitofwork import UOWTransaction as UOWTransaction -from .util import aliased as aliased from .util import Bundle as Bundle from .util import CascadeOptions as CascadeOptions -from .util import join as join from .util import LoaderCriteriaOption as LoaderCriteriaOption from .util import object_mapper as object_mapper -from .util import outerjoin as outerjoin from .util import polymorphic_union as polymorphic_union from .util import was_deleted as was_deleted from .util import with_parent as with_parent -from .util import with_polymorphic as with_polymorphic from .. import util as _sa_util diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 80607670eb..a1f1faa053 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -7,35 +7,52 @@ import typing from typing import Any -from typing import Callable from typing import Collection +from typing import List +from typing import Mapping from typing import Optional from typing import overload +from typing import Set from typing import Type from typing import Union from . import mapper as mapperlib from .base import Mapped -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .mapper import Mapper from .properties import ColumnProperty +from .properties import MappedColumn from .query import AliasOption -from .relationships import RelationshipProperty +from .relationships import _RelationshipArgumentType +from .relationships import Relationship from .session import Session +from .util import _ORMJoin +from .util import AliasedClass +from .util import AliasedInsp from .util import LoaderCriteriaOption from .. import sql from .. import util from ..exc import InvalidRequestError -from ..sql.schema import Column -from ..sql.schema import SchemaEventTarget +from ..sql.base import SchemaEventTarget +from ..sql.selectable import Alias +from ..sql.selectable import FromClause from ..sql.type_api import TypeEngine from ..util.typing import Literal - -_RC = typing.TypeVar("_RC") _T = typing.TypeVar("_T") +CompositeProperty = Composite +"""Alias for :class:`_orm.Composite`.""" + +RelationshipProperty = Relationship +"""Alias for :class:`_orm.Relationship`.""" + +SynonymProperty = Synonym +"""Alias for :class:`_orm.Synonym`.""" + + @util.deprecated( "1.4", "The :class:`.AliasOption` object is not necessary " @@ -51,35 +68,45 @@ def contains_alias(alias) -> "AliasOption": return AliasOption(alias) +# see test/ext/mypy/plain_files/mapped_column.py for mapped column +# typing tests + + @overload def mapped_column( + __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., + nullable: Literal[None] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped": +) -> "MappedColumn[Any]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[None] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[Any]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[True] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[Optional[_T]]": ... @@ -87,45 +114,48 @@ def mapped_column( def mapped_column( __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[False]] = ..., - primary_key: Literal[True] = True, + nullable: Literal[True] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[Optional[_T]]": ... @overload def mapped_column( + __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, nullable: Literal[False] = ..., - primary_key: bool = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( - __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: Literal[False] = ..., + primary_key: Literal[None] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( - __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[True]] = ..., - primary_key: Union[Literal[None], Literal[False]] = ..., + nullable: bool = ..., + primary_key: Literal[True] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[Optional[_T]]": +) -> "MappedColumn[_T]": ... @@ -134,55 +164,209 @@ def mapped_column( __name: str, __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Union[Literal[None], Literal[False]] = ..., - primary_key: Literal[True] = True, + nullable: bool = ..., + primary_key: Literal[True] = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[_T]": ... @overload def mapped_column( __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], *args: SchemaEventTarget, - nullable: Literal[False] = ..., + nullable: bool = ..., primary_key: bool = ..., + deferred: bool = ..., **kw: Any, -) -> "Mapped[_T]": +) -> "MappedColumn[Any]": ... -def mapped_column(*args, **kw) -> "Mapped": - """construct a new ORM-mapped :class:`_schema.Column` construct. +@overload +def mapped_column( + *args: SchemaEventTarget, + nullable: bool = ..., + primary_key: bool = ..., + deferred: bool = ..., + **kw: Any, +) -> "MappedColumn[Any]": + ... + + +def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": + r"""construct a new ORM-mapped :class:`_schema.Column` construct. + + The :func:`_orm.mapped_column` function provides an ORM-aware and + Python-typing-compatible construct which is used with + :ref:`declarative ` mappings to indicate an + attribute that's mapped to a Core :class:`_schema.Column` object. It + provides the equivalent feature as mapping an attribute to a + :class:`_schema.Column` object directly when using declarative. + + .. versionadded:: 2.0 - The :func:`_orm.mapped_column` function is shorthand for the construction - of a Core :class:`_schema.Column` object delivered within a - :func:`_orm.column_property` construct, which provides for consistent - typing information to be delivered to the class so that it works under - static type checkers such as mypy and delivers useful information in - IDE related type checkers such as pylance. The function can be used - in declarative mappings anywhere that :class:`_schema.Column` is normally - used:: + :func:`_orm.mapped_column` is normally used with explicit typing along with + the :class:`_orm.Mapped` mapped attribute type, where it can derive the SQL + type and nullability for the column automatically, such as:: + from typing import Optional + + from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column class User(Base): __tablename__ = 'user' - id = mapped_column(Integer) - name = mapped_column(String) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + options: Mapped[Optional[str]] = mapped_column() + + In the above example, the ``int`` and ``str`` types are inferred by the + Declarative mapping system to indicate use of the :class:`_types.Integer` + and :class:`_types.String` datatypes, and the presence of ``Optional`` or + not indicates whether or not each non-primary-key column is to be + ``nullable=True`` or ``nullable=False``. + + The above example, when interpreted within a Declarative class, will result + in a table named ``"user"`` which is equivalent to the following:: + + from sqlalchemy import Integer + from sqlalchemy import String + from sqlalchemy import Table + + Table( + 'user', + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String, nullable=False), + Column("options", String, nullable=True), + ) + The :func:`_orm.mapped_column` construct accepts the same arguments as + that of :class:`_schema.Column` directly, including optional "name" + and "type" fields, so the above mapping can be stated more explicitly + as:: - .. versionadded:: 2.0 + from typing import Optional + + from sqlalchemy import Integer + from sqlalchemy import String + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column("id", Integer, primary_key=True) + name: Mapped[str] = mapped_column("name", String, nullable=False) + options: Mapped[Optional[str]] = mapped_column( + "name", String, nullable=True + ) + + Arguments passed to :func:`_orm.mapped_column` always supersede those which + would be derived from the type annotation and/or attribute name. To state + the above mapping with more specific datatypes for ``id`` and ``options``, + and a different column name for ``name``, looks like:: + + from sqlalchemy import BigInteger + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True) + name: Mapped[str] = mapped_column("user_name") + options: Mapped[Optional[str]] = mapped_column(String(50)) + + Where again, datatypes and nullable parameters that can be automatically + derived may be omitted. + + The datatypes passed to :class:`_orm.Mapped` are mapped to SQL + :class:`_types.TypeEngine` types with the following default mapping:: + + _type_map = { + int: Integer(), + float: Float(), + bool: Boolean(), + decimal.Decimal: Numeric(), + dt.date: Date(), + dt.datetime: DateTime(), + dt.time: Time(), + dt.timedelta: Interval(), + util.NoneType: NULLTYPE, + bytes: LargeBinary(), + str: String(), + } + + The above mapping may be expanded to include any combination of Python + datatypes to SQL types by using the + :paramref:`_orm.registry.type_annotation_map` parameter to + :class:`_orm.registry`, or as the attribute ``type_annotation_map`` upon + the :class:`_orm.DeclarativeBase` base class. + + Finally, :func:`_orm.mapped_column` is implicitly used by the Declarative + mapping system for any :class:`_orm.Mapped` annotation that has no + attribute value set up. This is much in the way that Python dataclasses + allow the ``field()`` construct to be optional, only needed when additional + parameters should be associated with the field. Using this functionality, + our original mapping can be stated even more succinctly as:: + + from typing import Optional + + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class User(Base): + __tablename__ = 'user' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + options: Mapped[Optional[str]] + + Above, the ``name`` and ``options`` columns will be evaluated as + ``Column("name", String, nullable=False)`` and + ``Column("options", String, nullable=True)``, respectively. + + :param __name: String name to give to the :class:`_schema.Column`. This + is an optional, positional only argument that if present must be the + first positional argument passed. If omitted, the attribute name to + which the :func:`_orm.mapped_column` is mapped will be used as the SQL + column name. + :param __type: :class:`_types.TypeEngine` type or instance which will + indicate the datatype to be associated with the :class:`_schema.Column`. + This is an optional, positional-only argument that if present must + immediately follow the ``__name`` parameter if present also, or otherwise + be the first positional parameter. If omitted, the ultimate type for + the column may be derived either from the annotated type, or if a + :class:`_schema.ForeignKey` is present, from the datatype of the + referenced column. + :param \*args: Additional positional arguments include constructs such + as :class:`_schema.ForeignKey`, :class:`_schema.CheckConstraint`, + and :class:`_schema.Identity`, which are passed through to the constructed + :class:`_schema.Column`. + :param nullable: Optional bool, whether the column should be "NULL" or + "NOT NULL". If omitted, the nullability is derived from the type + annotation based on whether or not ``typing.Optional`` is present. + ``nullable`` defaults to ``True`` otherwise for non-primary key columns, + and ``False`` or primary key columns. + :param primary_key: optional bool, indicates the :class:`_schema.Column` + would be part of the table's primary key or not. + :param deferred: Optional bool - this keyword argument is consumed by the + ORM declarative process, and is not part of the :class:`_schema.Column` + itself; instead, it indicates that this column should be "deferred" for + loading as though mapped by :func:`_orm.deferred`. + :param \**kw: All remaining keyword argments are passed through to the + constructor for the :class:`_schema.Column`. """ - return column_property(Column(*args, **kw)) + + return MappedColumn(*args, **kw) def column_property( column: sql.ColumnElement[_T], *additional_columns, **kwargs -) -> "Mapped[_T]": +) -> "ColumnProperty[_T]": r"""Provide a column-level property for use with a mapping. Column-based properties can normally be applied to the mapper's @@ -269,22 +453,49 @@ def column_property( return ColumnProperty(column, *additional_columns, **kwargs) -def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": +@overload +def composite( + class_: Type[_T], + *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + **kwargs: Any, +) -> "Composite[_T]": + ... + + +@overload +def composite( + *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + **kwargs: Any, +) -> "Composite[Any]": + ... + + +def composite( + class_: Any = None, + *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + **kwargs: Any, +) -> "Composite[Any]": r"""Return a composite column-based property for use with a Mapper. See the mapping documentation section :ref:`mapper_composite` for a full usage example. The :class:`.MapperProperty` returned by :func:`.composite` - is the :class:`.CompositeProperty`. + is the :class:`.Composite`. :param class\_: The "composite type" class, or any classmethod or callable which will produce a new instance of the composite object given the column values in order. - :param \*cols: - List of Column objects to be mapped. + :param \*attrs: + List of elements to be mapped, which may include: + + * :class:`_schema.Column` objects + * :func:`_orm.mapped_column` constructs + * string names of other attributes on the mapped class, which may be + any other SQL or object-mapped attribute. This can for + example allow a composite that refers to a many-to-one relationship :param active_history=False: When ``True``, indicates that the "previous" value for a @@ -301,7 +512,7 @@ def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": :func:`~sqlalchemy.orm.deferred`. :param comparator_factory: a class which extends - :class:`.CompositeProperty.Comparator` which provides custom SQL + :class:`.Composite.Comparator` which provides custom SQL clause generation for comparison operations. :param doc: @@ -312,7 +523,7 @@ def composite(class_: Type[_T], *attrs, **kwargs) -> "Mapped[_T]": :attr:`.MapperProperty.info` attribute of this object. """ - return CompositeProperty(class_, *attrs, **kwargs) + return Composite(class_, *attrs, **kwargs) def with_loader_criteria( @@ -500,143 +711,140 @@ def with_loader_criteria( @overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], + secondary=None, + *, + uselist: Literal[False] = None, + collection_class: Literal[None] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[_T]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, uselist: Literal[True] = None, + collection_class: Literal[None] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - order_by=False, - backref=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[Collection[_RC]]: + **kw: Any, +) -> Relationship[List[_T]]: ... @overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, - uselist: Optional[bool] = None, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[List] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - order_by=False, - backref=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[_RC]: + **kw: Any, +) -> Relationship[List[_T]]: ... +@overload def relationship( - argument: Union[str, Type[_RC], Callable[[], Type[_RC]]], + argument: Optional[_RelationshipArgumentType[_T]], secondary=None, *, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[Set] = None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, + back_populates=None, + **kw: Any, +) -> Relationship[Set[_T]]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]], + secondary=None, + *, + uselist: Union[Literal[None], Literal[True]] = None, + collection_class: Type[Mapping[Any, Any]] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Mapping[Any, _T]]: + ... + + +@overload +def relationship( + argument: _RelationshipArgumentType[_T], + secondary=None, + *, + uselist: Literal[None] = None, + collection_class: Literal[None] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary=None, + *, + uselist: Literal[True] = None, + collection_class: Any = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: + ... + + +@overload +def relationship( + argument: Literal[None] = None, + secondary=None, + *, uselist: Optional[bool] = None, - order_by=False, - backref=None, + collection_class: Any = None, + primaryjoin=None, + secondaryjoin=None, back_populates=None, - overlaps=None, - post_update=False, - cascade=False, - viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=RelationshipProperty._persistence_only["passive_deletes"], - passive_updates=RelationshipProperty._persistence_only["passive_updates"], - remote_side=None, - enable_typechecks=RelationshipProperty._persistence_only[ - "enable_typechecks" - ], - join_depth=None, - comparator_factory=None, - single_parent=False, - innerjoin=False, - distinct_target_key=None, - doc=None, - active_history=RelationshipProperty._persistence_only["active_history"], - cascade_backrefs=RelationshipProperty._persistence_only[ - "cascade_backrefs" - ], - load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, - query_class=None, - info=None, - omit_join=None, - sync_backref=None, - _legacy_inactive_history_style=False, -) -> Mapped[_RC]: + **kw: Any, +) -> Relationship[Any]: + ... + + +def relationship( + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary=None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[Type[Collection]] = None, + primaryjoin=None, + secondaryjoin=None, + back_populates=None, + **kw: Any, +) -> Relationship[Any]: """Provide a relationship between two mapped classes. This corresponds to a parent-child or associative table relationship. The constructed class is an instance of - :class:`.RelationshipProperty`. + :class:`.Relationship`. A typical :func:`_orm.relationship`, used in a classical mapping:: @@ -897,7 +1105,7 @@ def relationship( examples. :param comparator_factory: - A class which extends :class:`.RelationshipProperty.Comparator` + A class which extends :class:`.Relationship.Comparator` which provides custom SQL clause generation for comparison operations. @@ -1447,42 +1655,15 @@ def relationship( """ - return RelationshipProperty( + return Relationship( argument, - secondary, - primaryjoin, - secondaryjoin, - foreign_keys, - uselist, - order_by, - backref, - back_populates, - overlaps, - post_update, - cascade, - viewonly, - lazy, - collection_class, - passive_deletes, - passive_updates, - remote_side, - enable_typechecks, - join_depth, - comparator_factory, - single_parent, - innerjoin, - distinct_target_key, - doc, - active_history, - cascade_backrefs, - load_on_pending, - bake_queries, - _local_remote_pairs, - query_class, - info, - omit_join, - sync_backref, - _legacy_inactive_history_style, + secondary=secondary, + uselist=uselist, + collection_class=collection_class, + primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, + back_populates=back_populates, + **kw, ) @@ -1493,7 +1674,7 @@ def synonym( comparator_factory=None, doc=None, info=None, -) -> "Mapped": +) -> "Synonym[Any]": """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior of another attribute. @@ -1597,9 +1778,7 @@ def synonym( than can be achieved with synonyms. """ - return SynonymProperty( - name, map_column, descriptor, comparator_factory, doc, info - ) + return Synonym(name, map_column, descriptor, comparator_factory, doc, info) def create_session(bind=None, **kwargs): @@ -1733,7 +1912,9 @@ def deferred(*columns, **kw): return ColumnProperty(deferred=True, *columns, **kw) -def query_expression(default_expr=sql.null()): +def query_expression( + default_expr: sql.ColumnElement[_T] = sql.null(), +) -> "Mapped[_T]": """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -1787,3 +1968,273 @@ def clear_mappers(): """ mapperlib._dispose_registries(mapperlib._all_registries(), False) + + +@overload +def aliased( + element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> "AliasedClass[_T]": + ... + + +@overload +def aliased( + element: "FromClause", + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> "Alias": + ... + + +def aliased( + element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"], + alias=None, + name=None, + flat=False, + adapt_on_names=False, +) -> Union["AliasedClass[_T]", "Alias"]: + """Produce an alias of the given element, usually an :class:`.AliasedClass` + instance. + + E.g.:: + + my_alias = aliased(MyClass) + + session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) + + The :func:`.aliased` function is used to create an ad-hoc mapping of a + mapped class to a new selectable. By default, a selectable is generated + from the normally mapped selectable (typically a :class:`_schema.Table` + ) using the + :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` + can also be + used to link the class to a new :func:`_expression.select` statement. + Also, the :func:`.with_polymorphic` function is a variant of + :func:`.aliased` that is intended to specify a so-called "polymorphic + selectable", that corresponds to the union of several joined-inheritance + subclasses at once. + + For convenience, the :func:`.aliased` function also accepts plain + :class:`_expression.FromClause` constructs, such as a + :class:`_schema.Table` or + :func:`_expression.select` construct. In those cases, the + :meth:`_expression.FromClause.alias` + method is called on the object and the new + :class:`_expression.Alias` object returned. The returned + :class:`_expression.Alias` is not + ORM-mapped in this case. + + .. seealso:: + + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` + + :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` + + :ref:`ormtutorial_aliases` - in the legacy :ref:`ormtutorial_toplevel` + + :param element: element to be aliased. Is normally a mapped class, + but for convenience can also be a :class:`_expression.FromClause` + element. + + :param alias: Optional selectable unit to map the element to. This is + usually used to link the object to a subquery, and should be an aliased + select construct as one would produce from the + :meth:`_query.Query.subquery` method or + the :meth:`_expression.Select.subquery` or + :meth:`_expression.Select.alias` methods of the :func:`_expression.select` + construct. + + :param name: optional string name to use for the alias, if not specified + by the ``alias`` parameter. The name, among other things, forms the + attribute name that will be accessible via tuples returned by a + :class:`_query.Query` object. Not supported when creating aliases + of :class:`_sql.Join` objects. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. + + :param adapt_on_names: if True, more liberal "matching" will be used when + mapping the mapped columns of the ORM entity to those of the + given selectable - a name-based match will be performed if the + given selectable doesn't otherwise have a column that corresponds + to one on the entity. The use case for this is when associating + an entity with some derived selectable such as one that uses + aggregate functions:: + + class UnitPrice(Base): + __tablename__ = 'unit_price' + ... + unit_id = Column(Integer) + price = Column(Numeric) + + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() + + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) + + Above, functions on ``aggregated_unit_price`` which refer to + ``.price`` will return the + ``func.sum(UnitPrice.price).label('price')`` column, as it is + matched on the name "price". Ordinarily, the "price" function + wouldn't have any "column correspondence" to the actual + ``UnitPrice.price`` column as it is not a proxy of the original. + + """ + return AliasedInsp._alias_factory( + element, + alias=alias, + name=name, + flat=flat, + adapt_on_names=adapt_on_names, + ) + + +def with_polymorphic( + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, +): + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + .. seealso:: + + :ref:`with_polymorphic` - full discussion of + :func:`_orm.with_polymorphic`. + + :param base: Base class to be aliased. + + :param classes: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be aliased. For a + JOIN, this means the JOIN will be SELECTed from inside of a subquery + unless the :paramref:`_orm.with_polymorphic.flat` flag is set to + True, which is recommended for simpler use cases. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. Setting this flag is + recommended as long as the resulting SQL is functional. + + :param selectable: a table or subquery that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. + + When left at its default value of ``False``, the polymorphic + selectable assigned to the base mapper is used for selecting rows. + However, it may also be passed as ``None``, which will bypass the + configured polymorphic selectable and instead construct an ad-hoc + selectable for the target classes given; for joined table inheritance + this will be a join that includes all target mappers and their + subclasses. + + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + """ + return AliasedInsp._with_polymorphic_factory( + base, + classes, + selectable=selectable, + flat=flat, + polymorphic_on=polymorphic_on, + aliased=aliased, + innerjoin=innerjoin, + _use_mapper_path=_use_mapper_path, + ) + + +def join( + left, right, onclause=None, isouter=False, full=False, join_to_left=None +): + r"""Produce an inner join between left and right clauses. + + :func:`_orm.join` is an extension to the core join interface + provided by :func:`_expression.join()`, where the + left and right selectables may be not only core selectable + objects such as :class:`_schema.Table`, but also mapped classes or + :class:`.AliasedClass` instances. The "on" clause can + be a SQL expression, or an attribute or string name + referencing a configured :func:`_orm.relationship`. + + :func:`_orm.join` is not commonly needed in modern usage, + as its functionality is encapsulated within that of the + :meth:`_query.Query.join` method, which features a + significant amount of automation beyond :func:`_orm.join` + by itself. Explicit usage of :func:`_orm.join` + with :class:`_query.Query` involves usage of the + :meth:`_query.Query.select_from` method, as in:: + + from sqlalchemy.orm import join + session.query(User).\ + select_from(join(User, Address, User.addresses)).\ + filter(Address.email_address=='foo@bar.com') + + In modern SQLAlchemy the above join can be written more + succinctly as:: + + session.query(User).\ + join(User.addresses).\ + filter(Address.email_address=='foo@bar.com') + + See :meth:`_query.Query.join` for information on modern usage + of ORM level joins. + + .. deprecated:: 0.8 + + the ``join_to_left`` parameter is deprecated, and will be removed + in a future release. The parameter has no effect. + + """ + return _ORMJoin(left, right, onclause, isouter, full) + + +def outerjoin(left, right, onclause=None, full=False, join_to_left=None): + """Produce a left outer join between left and right clauses. + + This is the "outer join" version of the :func:`_orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. + + """ + return _ORMJoin(left, right, onclause, True, full) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 5a605b7c65..fbfb2b2ee0 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -35,6 +35,7 @@ from .base import instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED from .base import manager_of_class +from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa from .base import NO_AUTOFLUSH from .base import NO_CHANGE # noqa @@ -79,6 +80,7 @@ class QueryableAttribute( traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, + roles.ColumnsClauseRole, sql_base.Immutable, sql_base.MemoizedHasCacheKey, ): @@ -190,7 +192,7 @@ class QueryableAttribute( construct has defined one). * If the attribute refers to any other kind of - :class:`.MapperProperty`, including :class:`.RelationshipProperty`, + :class:`.MapperProperty`, including :class:`.Relationship`, the attribute will refer to the :attr:`.MapperProperty.info` dictionary associated with that :class:`.MapperProperty`. @@ -352,7 +354,7 @@ class QueryableAttribute( Return values here will commonly be instances of - :class:`.ColumnProperty` or :class:`.RelationshipProperty`. + :class:`.ColumnProperty` or :class:`.Relationship`. """ diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 7ab4b77375..e6d4a67298 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -12,8 +12,11 @@ import operator import typing from typing import Any +from typing import Callable from typing import Generic +from typing import Optional from typing import overload +from typing import Tuple from typing import TypeVar from typing import Union @@ -22,8 +25,9 @@ from .. import exc as sa_exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations -from ..util import typing as compat_typing from ..util.langhelpers import TypingOnly +from ..util.typing import Concatenate +from ..util.typing import ParamSpec if typing.TYPE_CHECKING: @@ -32,6 +36,9 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T", bound=Any) +_IdentityKeyType = Tuple[type, Tuple[Any, ...], Optional[str]] + + PASSIVE_NO_RESULT = util.symbol( "PASSIVE_NO_RESULT", """Symbol returned by a loader callable or other attribute/history @@ -236,16 +243,16 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = typing.TypeVar("_Fn", bound=typing.Callable) -_Args = compat_typing.ParamSpec("_Args") -_Self = typing.TypeVar("_Self") +_Fn = TypeVar("_Fn", bound=Callable) +_Args = ParamSpec("_Args") +_Self = TypeVar("_Self") def _assertions( - *assertions, -) -> typing.Callable[ - [typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self]], - typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self], + *assertions: Any, +) -> Callable[ + [Callable[Concatenate[_Self, _Fn, _Args], _Self]], + Callable[Concatenate[_Self, _Fn, _Args], _Self], ]: @util.decorator def generate( @@ -605,8 +612,8 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): ... -class Mapped(Generic[_T], util.TypingOnly): - """Represent an ORM mapped attribute for typing purposes. +class Mapped(Generic[_T], TypingOnly): + """Represent an ORM mapped attribute on a mapped class. This class represents the complete descriptor interface for any class attribute that will have been :term:`instrumented` by the ORM @@ -650,7 +657,7 @@ class Mapped(Generic[_T], util.TypingOnly): ... @classmethod - def _empty_constructor(cls, arg1: Any) -> "SQLORMOperations[_T]": + def _empty_constructor(cls, arg1: Any) -> "Mapped[_T]": ... @overload diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 3bf7ddde8f..c24b3c6969 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -10,11 +10,14 @@ This system allows specification of classes and expressions used in :func:`_orm.relationship` using strings. """ +import re +from typing import MutableMapping +from typing import Union import weakref from . import attributes from . import interfaces -from .descriptor_props import SynonymProperty +from .descriptor_props import Synonym from .properties import ColumnProperty from .util import class_mapper from .. import exc @@ -22,6 +25,8 @@ from .. import inspection from .. import util from ..sql.schema import _get_table_key +_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] + # strong references to registries which we place in # the _decl_class_registry, which is usually weak referencing. # the internal registries here link to classes with weakrefs and remove @@ -118,7 +123,13 @@ def _key_is_empty(key, decl_class_registry, test): return not test(thing) -class _MultipleClassMarker: +class ClsRegistryToken: + """an object that can be in the registry._class_registry as a value.""" + + __slots__ = () + + +class _MultipleClassMarker(ClsRegistryToken): """refers to multiple classes of the same name within _decl_class_registry. @@ -182,7 +193,7 @@ class _MultipleClassMarker: self.contents.add(weakref.ref(item, self._remove_item)) -class _ModuleMarker: +class _ModuleMarker(ClsRegistryToken): """Refers to a module name within _decl_class_registry. @@ -281,7 +292,7 @@ class _GetColumns: desc = mp.all_orm_descriptors[key] if desc.extension_type is interfaces.NOT_EXTENSION: prop = desc.property - if isinstance(prop, SynonymProperty): + if isinstance(prop, Synonym): key = prop.name elif not isinstance(prop, ColumnProperty): raise exc.InvalidRequestError( @@ -372,13 +383,26 @@ class _class_resolver: return self.fallback[key] def _raise_for_name(self, name, err): - raise exc.InvalidRequestError( - "When initializing mapper %s, expression %r failed to " - "locate a name (%r). If this is a class name, consider " - "adding this relationship() to the %r class after " - "both dependent classes have been defined." - % (self.prop.parent, self.arg, name, self.cls) - ) from err + generic_match = re.match(r"(.+)\[(.+)\]", name) + + if generic_match: + raise exc.InvalidRequestError( + f"When initializing mapper {self.prop.parent}, " + f'expression "relationship({self.arg!r})" seems to be ' + "using a generic class as the argument to relationship(); " + "please state the generic argument " + "using an annotation, e.g. " + f'"{self.prop.key}: Mapped[{generic_match.group(1)}' + f'[{generic_match.group(2)}]] = relationship()"' + ) from err + else: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." + % (self.prop.parent, self.arg, name, self.cls) + ) from err def _resolve_name(self): name = self.arg diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 75ce8216f6..ba4225563d 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -102,18 +102,20 @@ The owning object and :class:`.CollectionAttributeImpl` are also reachable through the adapter, allowing for some very sophisticated behavior. """ - import operator import threading +import typing import weakref -from sqlalchemy.util.compat import inspect_getfullargspec -from . import base from .. import exc as sa_exc from .. import util -from ..sql import coercions -from ..sql import expression -from ..sql import roles +from ..util.compat import inspect_getfullargspec + +if typing.TYPE_CHECKING: + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import mapped_collection + from .mapped_collection import MappedCollection # noqa: F401 __all__ = [ "collection", @@ -126,180 +128,6 @@ __all__ = [ __instrumentation_mutex = threading.Lock() -class _PlainColumnGetter: - """Plain column getter, stores collection of Column objects - directly. - - Serializes to a :class:`._SerializableColumnGetterV2` - which has more expensive __call__() performance - and some rare caveats. - - """ - - def __init__(self, cols): - self.cols = cols - self.composite = len(cols) > 1 - - def __reduce__(self): - return _SerializableColumnGetterV2._reduce_from_cols(self.cols) - - def _cols(self, mapper): - return self.cols - - def __call__(self, value): - state = base.instance_state(value) - m = base._state_mapper(state) - - key = [ - m._get_state_attr_by_column(state, state.dict, col) - for col in self._cols(m) - ] - - if self.composite: - return tuple(key) - else: - return key[0] - - -class _SerializableColumnGetter: - """Column-based getter used in version 0.7.6 only. - - Remains here for pickle compatibility with 0.7.6. - - """ - - def __init__(self, colkeys): - self.colkeys = colkeys - self.composite = len(colkeys) > 1 - - def __reduce__(self): - return _SerializableColumnGetter, (self.colkeys,) - - def __call__(self, value): - state = base.instance_state(value) - m = base._state_mapper(state) - key = [ - m._get_state_attr_by_column( - state, state.dict, m.mapped_table.columns[k] - ) - for k in self.colkeys - ] - if self.composite: - return tuple(key) - else: - return key[0] - - -class _SerializableColumnGetterV2(_PlainColumnGetter): - """Updated serializable getter which deals with - multi-table mapped classes. - - Two extremely unusual cases are not supported. - Mappings which have tables across multiple metadata - objects, or which are mapped to non-Table selectables - linked across inheriting mappers may fail to function - here. - - """ - - def __init__(self, colkeys): - self.colkeys = colkeys - self.composite = len(colkeys) > 1 - - def __reduce__(self): - return self.__class__, (self.colkeys,) - - @classmethod - def _reduce_from_cols(cls, cols): - def _table_key(c): - if not isinstance(c.table, expression.TableClause): - return None - else: - return c.table.key - - colkeys = [(c.key, _table_key(c)) for c in cols] - return _SerializableColumnGetterV2, (colkeys,) - - def _cols(self, mapper): - cols = [] - metadata = getattr(mapper.local_table, "metadata", None) - for (ckey, tkey) in self.colkeys: - if tkey is None or metadata is None or tkey not in metadata: - cols.append(mapper.local_table.c[ckey]) - else: - cols.append(metadata.tables[tkey].c[ckey]) - return cols - - -def column_mapped_collection(mapping_spec): - """A dictionary-based collection type with column-based keying. - - Returns a :class:`.MappedCollection` factory with a keying function - generated from mapping_spec, which may be a Column or a sequence - of Columns. - - The key value must be immutable for the lifetime of the object. You - can not, for example, map on foreign key values if those key values will - change during the session, i.e. from None to a database-assigned integer - after a session flush. - - """ - cols = [ - coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") - for q in util.to_list(mapping_spec) - ] - keyfunc = _PlainColumnGetter(cols) - return lambda: MappedCollection(keyfunc) - - -class _SerializableAttrGetter: - def __init__(self, name): - self.name = name - self.getter = operator.attrgetter(name) - - def __call__(self, target): - return self.getter(target) - - def __reduce__(self): - return _SerializableAttrGetter, (self.name,) - - -def attribute_mapped_collection(attr_name): - """A dictionary-based collection type with attribute-based keying. - - Returns a :class:`.MappedCollection` factory with a keying based on the - 'attr_name' attribute of entities in the collection, where ``attr_name`` - is the string name of the attribute. - - .. warning:: the key value must be assigned to its final value - **before** it is accessed by the attribute mapped collection. - Additionally, changes to the key attribute are **not tracked** - automatically, which means the key in the dictionary is not - automatically synchronized with the key value on the target object - itself. See the section :ref:`key_collections_mutations` - for an example. - - """ - getter = _SerializableAttrGetter(attr_name) - return lambda: MappedCollection(getter) - - -def mapped_collection(keyfunc): - """A dictionary-based collection type with arbitrary keying. - - Returns a :class:`.MappedCollection` factory with a keying function - generated from keyfunc, a callable that takes an entity and returns a - key value. - - The key value must be immutable for the lifetime of the object. You - can not, for example, map on foreign key values if those key values will - change during the session, i.e. from None to a database-assigned integer - after a session flush. - - """ - return lambda: MappedCollection(keyfunc) - - class collection: """Decorators for entity collection classes. @@ -1620,63 +1448,24 @@ __interfaces = { } -class MappedCollection(dict): - """A basic dictionary-based collection class. - - Extends dict with the minimal bag semantics that collection - classes require. ``set`` and ``remove`` are implemented in terms - of a keying function: any callable that takes an object and - returns an object for use as a dictionary key. - - """ - - def __init__(self, keyfunc): - """Create a new collection with keying provided by keyfunc. +def __go(lcls): - keyfunc may be any callable that takes an object and returns an object - for use as a dictionary key. + global mapped_collection, column_mapped_collection + global attribute_mapped_collection, MappedCollection - The keyfunc will be called every time the ORM needs to add a member by - value-only (such as when loading instances from the database) or - remove a member. The usual cautions about dictionary keying apply- - ``keyfunc(object)`` should return the same output for the life of the - collection. Keying based on mutable properties can result in - unreachable instances "lost" in the collection. + from .mapped_collection import mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import MappedCollection - """ - self.keyfunc = keyfunc - - @collection.appender - @collection.internally_instrumented - def set(self, value, _sa_initiator=None): - """Add an item by value, consulting the keyfunc for the key.""" - - key = self.keyfunc(value) - self.__setitem__(key, value, _sa_initiator) - - @collection.remover - @collection.internally_instrumented - def remove(self, value, _sa_initiator=None): - """Remove an item by value, consulting the keyfunc for the key.""" - - key = self.keyfunc(value) - # Let self[key] raise if key is not in this collection - # testlib.pragma exempt:__ne__ - if self[key] != value: - raise sa_exc.InvalidRequestError( - "Can not remove '%s': collection holds '%s' for key '%s'. " - "Possible cause: is the MappedCollection key function " - "based on mutable properties or properties that only obtain " - "values after flush?" % (value, self[key], key) - ) - self.__delitem__(key, _sa_initiator) + # ensure instrumentation is associated with + # these built-in classes; if a user-defined class + # subclasses these and uses @internally_instrumented, + # the superclass is otherwise not instrumented. + # see [ticket:2406]. + _instrument_class(InstrumentedList) + _instrument_class(InstrumentedSet) + _instrument_class(MappedCollection) -# ensure instrumentation is associated with -# these built-in classes; if a user-defined class -# subclasses these and uses @internally_instrumented, -# the superclass is otherwise not instrumented. -# see [ticket:2406]. -_instrument_class(MappedCollection) -_instrument_class(InstrumentedList) -_instrument_class(InstrumentedSet) +__go(locals()) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 8e9cf66e28..34f291864f 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -5,16 +5,18 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php import itertools +from typing import List from . import attributes from . import interfaces from . import loading from .base import _is_aliased_class +from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole from .path_registry import PathRegistry from .util import _entity_corresponds_to from .util import _ORMJoin -from .util import aliased +from .util import AliasedClass from .util import Bundle from .util import ORMAdapter from .. import exc as sa_exc @@ -1570,7 +1572,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # when we are here, it means join() was called with an indicator # as to an exact left side, which means a path to a - # RelationshipProperty was given, e.g.: + # Relationship was given, e.g.: # # join(RightEntity, LeftEntity.right) # @@ -1725,7 +1727,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): need_adapter = True # make the right hand side target into an ORM entity - right = aliased(right_mapper, right_selectable) + right = AliasedClass(right_mapper, right_selectable) util.warn_deprecated( "An alias is being generated automatically against " @@ -1750,7 +1752,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # test/orm/inheritance/test_relationships.py. There are also # general overlap cases with many-to-many tables where automatic # aliasing is desirable. - right = aliased(right, flat=True) + right = AliasedClass(right, flat=True) need_adapter = True util.warn( @@ -1910,7 +1912,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _column_descriptions( query_or_select_stmt, compile_state=None, legacy=False -): +) -> List[ORMColumnDescription]: if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( query_or_select_stmt, legacy=legacy diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 59fabb9b6b..5ac9966dd0 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -11,7 +11,9 @@ import typing from typing import Any from typing import Callable from typing import ClassVar +from typing import Mapping from typing import Optional +from typing import Type from typing import TypeVar from typing import Union import weakref @@ -31,7 +33,7 @@ from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute from .decl_base import _mapper -from .descriptor_props import SynonymProperty as _orm_synonym +from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper from .. import exc from .. import inspection @@ -39,14 +41,18 @@ from .. import util from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData from ..sql.selectable import FromClause +from ..sql.type_api import TypeEngine from ..util import hybridmethod from ..util import hybridproperty +from ..util import typing as compat_typing if typing.TYPE_CHECKING: from .state import InstanceState # noqa _T = TypeVar("_T", bound=Any) +_TypeAnnotationMapType = Mapping[Type, Union[Type[TypeEngine], TypeEngine]] + def has_inherited_table(cls): """Given a class, return True if any of the classes it inherits from has a @@ -67,8 +73,22 @@ def has_inherited_table(cls): return False +class _DynamicAttributesType(type): + def __setattr__(cls, key, value): + if "__mapper__" in cls.__dict__: + _add_attribute(cls, key, value) + else: + type.__setattr__(cls, key, value) + + def __delattr__(cls, key): + if "__mapper__" in cls.__dict__: + _del_attribute(cls, key) + else: + type.__delattr__(cls, key) + + class DeclarativeAttributeIntercept( - type, inspection.Inspectable["Mapper[Any]"] + _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] ): """Metaclass that may be used in conjunction with the :class:`_orm.DeclarativeBase` class to support addition of class @@ -76,15 +96,16 @@ class DeclarativeAttributeIntercept( """ - def __setattr__(cls, key, value): - _add_attribute(cls, key, value) - - def __delattr__(cls, key): - _del_attribute(cls, key) +class DeclarativeMeta( + _DynamicAttributesType, inspection.Inspectable["Mapper[Any]"] +): + metadata: MetaData + registry: "RegistryType" -class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): - def __init__(cls, classname, bases, dict_, **kw): + def __init__( + cls, classname: Any, bases: Any, dict_: Any, **kw: Any + ) -> None: # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named # "registry" @@ -103,12 +124,6 @@ class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]): _as_declarative(reg, cls, dict_) type.__init__(cls, classname, bases, dict_) - def __setattr__(cls, key, value): - _add_attribute(cls, key, value) - - def __delattr__(cls, key): - _del_attribute(cls, key) - def synonym_for(name, map_column=False): """Decorator that produces an :func:`_orm.synonym` @@ -250,6 +265,9 @@ class declared_attr(interfaces._MappedAttribute[_T]): self._cascading = cascading self.__doc__ = fn.__doc__ + def _collect_return_annotation(self) -> Optional[Type[Any]]: + return util.get_annotations(self.fget).get("return") + def __get__(self, instance, owner) -> InstrumentedAttribute[_T]: # the declared_attr needs to make use of a cache that exists # for the span of the declarative scan_attributes() phase. @@ -409,6 +427,11 @@ def _setup_declarative_base(cls): else: metadata = None + if "type_annotation_map" in cls.__dict__: + type_annotation_map = cls.__dict__["type_annotation_map"] + else: + type_annotation_map = None + reg = cls.__dict__.get("registry", None) if reg is not None: if not isinstance(reg, registry): @@ -416,8 +439,18 @@ def _setup_declarative_base(cls): "Declarative base class has a 'registry' attribute that is " "not an instance of sqlalchemy.orm.registry()" ) + elif type_annotation_map is not None: + raise exc.InvalidRequestError( + "Declarative base class has both a 'registry' attribute and a " + "type_annotation_map entry. Per-base type_annotation_maps " + "are not supported. Please apply the type_annotation_map " + "to this registry directly." + ) + else: - reg = registry(metadata=metadata) + reg = registry( + metadata=metadata, type_annotation_map=type_annotation_map + ) cls.registry = reg cls._sa_registry = reg @@ -476,6 +509,44 @@ class DeclarativeBase( mappings. The superclass makes use of the ``__init_subclass__()`` method to set up new classes and metaclasses aren't used. + When first used, the :class:`_orm.DeclarativeBase` class instantiates a new + :class:`_orm.registry` to be used with the base, assuming one was not + provided explicitly. The :class:`_orm.DeclarativeBase` class supports + class-level attributes which act as parameters for the construction of this + registry; such as to indicate a specific :class:`_schema.MetaData` + collection as well as a specific value for + :paramref:`_orm.registry.type_annotation_map`:: + + from typing import Annotation + + from sqlalchemy import BigInteger + from sqlalchemy import MetaData + from sqlalchemy import String + from sqlalchemy.orm import DeclarativeBase + + bigint = Annotation(int, "bigint") + my_metadata = MetaData() + + class Base(DeclarativeBase): + metadata = my_metadata + type_annotation_map = { + str: String().with_variant(String(255), "mysql", "mariadb"), + bigint: BigInteger() + } + + Class-level attributes which may be specified include: + + :param metadata: optional :class:`_schema.MetaData` collection. + If a :class:`_orm.registry` is constructed automatically, this + :class:`_schema.MetaData` collection will be used to construct it. + Otherwise, the local :class:`_schema.MetaData` collection will supercede + that used by an existing :class:`_orm.registry` passed using the + :paramref:`_orm.DeclarativeBase.registry` parameter. + :param type_annotation_map: optional type annotation map that will be + passed to the :class:`_orm.registry` as + :paramref:`_orm.registry.type_annotation_map`. + :param registry: supply a pre-existing :class:`_orm.registry` directly. + .. versionadded:: 2.0 """ @@ -516,12 +587,13 @@ def add_mapped_attribute(target, key, attr): def declarative_base( - metadata=None, + metadata: Optional[MetaData] = None, mapper=None, cls=object, name="Base", - constructor=_declarative_constructor, - class_registry=None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, metaclass=DeclarativeMeta, ) -> Any: r"""Construct a base class for declarative class definitions. @@ -593,6 +665,14 @@ def declarative_base( to share the same registry of class names for simplified inter-base relationships. + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + :param metaclass: Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ compatible callable to use as the meta type of the generated @@ -608,6 +688,7 @@ def declarative_base( metadata=metadata, class_registry=class_registry, constructor=constructor, + type_annotation_map=type_annotation_map, ).generate_base( mapper=mapper, cls=cls, @@ -651,9 +732,10 @@ class registry: def __init__( self, - metadata=None, - class_registry=None, - constructor=_declarative_constructor, + metadata: Optional[MetaData] = None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, ): r"""Construct a new :class:`_orm.registry` @@ -679,6 +761,14 @@ class registry: to share the same registry of class names for simplified inter-base relationships. + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + """ lcl_metadata = metadata or MetaData() @@ -690,7 +780,9 @@ class registry: self._non_primary_mappers = weakref.WeakKeyDictionary() self.metadata = lcl_metadata self.constructor = constructor - + self.type_annotation_map = {} + if type_annotation_map is not None: + self.update_type_annotation_map(type_annotation_map) self._dependents = set() self._dependencies = set() @@ -699,6 +791,25 @@ class registry: with mapperlib._CONFIGURE_MUTEX: mapperlib._mapper_registries[self] = True + def update_type_annotation_map( + self, + type_annotation_map: Mapping[ + Type, Union[Type[TypeEngine], TypeEngine] + ], + ) -> None: + """update the :paramref:`_orm.registry.type_annotation_map` with new + values.""" + + self.type_annotation_map.update( + { + sub_type: sqltype + for typ, sqltype in type_annotation_map.items() + for sub_type in compat_typing.expand_unions( + typ, include_union=True, discard_none=True + ) + } + ) + @property def mappers(self): """read only collection of all :class:`_orm.Mapper` objects.""" @@ -1131,6 +1242,9 @@ class registry: return _mapper(self, class_, local_table, kw) +RegistryType = registry + + def as_declarative(**kw): """ Class decorator which will adapt a given class into a diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index fb736806c4..342aa772b0 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -5,23 +5,34 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php """Internal implementation for declarative.""" + +from __future__ import annotations + import collections +from typing import Any +from typing import Dict +from typing import Tuple import weakref -from sqlalchemy.orm import attributes -from sqlalchemy.orm import instrumentation +from . import attributes from . import clsregistry from . import exc as orm_exc +from . import instrumentation from . import mapperlib from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class from .base import InspectionAttr -from .descriptor_props import CompositeProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MappedAttribute +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .mapper import Mapper as mapper from .properties import ColumnProperty +from .properties import MappedColumn +from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc @@ -130,7 +141,7 @@ def _mapper(registry, cls, table, mapper_kw): @util.preload_module("sqlalchemy.orm.decl_api") -def _is_declarative_props(obj): +def _is_declarative_props(obj: Any) -> bool: declared_attr = util.preloaded.orm_decl_api.declared_attr return isinstance(obj, (declared_attr, util.classproperty)) @@ -208,7 +219,7 @@ class _MapperConfig: class _ImperativeMapperConfig(_MapperConfig): - __slots__ = ("dict_", "local_table", "inherits") + __slots__ = ("local_table", "inherits") def __init__( self, @@ -221,7 +232,6 @@ class _ImperativeMapperConfig(_MapperConfig): registry, cls_, mapper_kw ) - self.dict_ = {} self.local_table = self.set_cls_attribute("__table__", table) with mapperlib._CONFIGURE_MUTEX: @@ -277,7 +287,10 @@ class _ImperativeMapperConfig(_MapperConfig): class _ClassScanMapperConfig(_MapperConfig): __slots__ = ( - "dict_", + "registry", + "clsdict_view", + "collected_attributes", + "collected_annotations", "local_table", "persist_selectable", "declared_columns", @@ -299,11 +312,17 @@ class _ClassScanMapperConfig(_MapperConfig): ): super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) - - self.dict_ = dict(dict_) if dict_ else {} + self.registry = registry self.persist_selectable = None - self.declared_columns = set() + + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) + self.collected_attributes = {} + self.collected_annotations: Dict[str, Tuple[Any, bool]] = {} + self.declared_columns = util.OrderedSet() self.column_copies = {} + self._setup_declared_events() self._scan_attributes() @@ -407,6 +426,19 @@ class _ClassScanMapperConfig(_MapperConfig): return attribute_is_overridden + _skip_attrs = frozenset( + [ + "__module__", + "__annotations__", + "__doc__", + "__dict__", + "__weakref__", + "_sa_class_manager", + "__dict__", + "__weakref__", + ] + ) + def _cls_attr_resolver(self, cls): """produce a function to iterate the "attributes" of a class, adjusting for SQLAlchemy fields embedded in dataclass fields. @@ -416,31 +448,52 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "__sa_dataclass_metadata_key__", None ) + cls_annotations = util.get_annotations(cls) + + cls_vars = vars(cls) + + skip = self._skip_attrs + + names = util.merge_lists_w_ordering( + [n for n in cls_vars if n not in skip], list(cls_annotations) + ) if sa_dataclass_metadata_key is None: def local_attributes_for_class(): - for name, obj in vars(cls).items(): - yield name, obj, False + return ( + ( + name, + cls_vars.get(name), + cls_annotations.get(name), + False, + ) + for name in names + ) else: - field_names = set() + dataclass_fields = { + field.name: field for field in util.local_dataclass_fields(cls) + } def local_attributes_for_class(): - for field in util.local_dataclass_fields(cls): - if sa_dataclass_metadata_key in field.metadata: - field_names.add(field.name) + for name in names: + field = dataclass_fields.get(name, None) + if field and sa_dataclass_metadata_key in field.metadata: yield field.name, _as_dc_declaredattr( field.metadata, sa_dataclass_metadata_key - ), True - for name, obj in vars(cls).items(): - if name not in field_names: - yield name, obj, False + ), cls_annotations.get(field.name), True + else: + yield name, cls_vars.get(name), cls_annotations.get( + name + ), False return local_attributes_for_class def _scan_attributes(self): cls = self.cls - dict_ = self.dict_ + + clsdict_view = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies mapper_args_fn = None table_args = inherited_table_args = None @@ -462,10 +515,16 @@ class _ClassScanMapperConfig(_MapperConfig): if not class_mapped and base is not cls: self._produce_column_copies( - local_attributes_for_class, attribute_is_overridden + local_attributes_for_class, + attribute_is_overridden, ) - for name, obj, is_dataclass in local_attributes_for_class(): + for ( + name, + obj, + annotation, + is_dataclass, + ) in local_attributes_for_class(): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -514,7 +573,12 @@ class _ClassScanMapperConfig(_MapperConfig): elif base is not cls: # we're a mixin, abstract base, or something that is # acting like that for now. - if isinstance(obj, Column): + + if isinstance(obj, (Column, MappedColumn)): + self.collected_annotations[name] = ( + annotation, + False, + ) # already copied columns to the mapped class. continue elif isinstance(obj, MapperProperty): @@ -526,8 +590,12 @@ class _ClassScanMapperConfig(_MapperConfig): "field() objects, use a lambda:" ) elif _is_declarative_props(obj): + # tried to get overloads to tell this to + # pylance, no luck + assert obj is not None + if obj._cascading: - if name in dict_: + if name in clsdict_view: # unfortunately, while we can use the user- # defined attribute here to allow a clean # override, if there's another @@ -541,7 +609,7 @@ class _ClassScanMapperConfig(_MapperConfig): "@declared_attr.cascading; " "skipping" % (name, cls) ) - dict_[name] = column_copies[ + collected_attributes[name] = column_copies[ obj ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) @@ -579,19 +647,36 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret = ret.descriptor - dict_[name] = column_copies[obj] = ret + collected_attributes[name] = column_copies[ + obj + ] = ret if ( isinstance(ret, (Column, MapperProperty)) and ret.doc is None ): ret.doc = obj.__doc__ - # here, the attribute is some other kind of property that - # we assume is not part of the declarative mapping. - # however, check for some more common mistakes + + self.collected_annotations[name] = ( + obj._collect_return_annotation(), + False, + ) + elif _is_mapped_annotation(annotation, cls): + self.collected_annotations[name] = ( + annotation, + is_dataclass, + ) + if obj is None: + collected_attributes[name] = MappedColumn() + else: + collected_attributes[name] = obj else: + # here, the attribute is some other kind of + # property that we assume is not part of the + # declarative mapping. however, check for some + # more common mistakes self._warn_for_decl_attributes(base, name, obj) elif is_dataclass and ( - name not in dict_ or dict_[name] is not obj + name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class # and not a superclass. this is currently a @@ -606,7 +691,20 @@ class _ClassScanMapperConfig(_MapperConfig): if _is_declarative_props(obj): obj = obj.fget() - dict_[name] = obj + collected_attributes[name] = obj + self.collected_annotations[name] = ( + annotation, + True, + ) + else: + self.collected_annotations[name] = ( + annotation, + False, + ) + if obj is None and _is_mapped_annotation(annotation, cls): + collected_attributes[name] = MappedColumn() + elif name in clsdict_view: + collected_attributes[name] = obj if inherited_table_args and not tablename: table_args = None @@ -618,46 +716,55 @@ class _ClassScanMapperConfig(_MapperConfig): def _warn_for_decl_attributes(self, cls, key, c): if isinstance(c, expression.ColumnClause): util.warn( - "Attribute '%s' on class %s appears to be a non-schema " - "'sqlalchemy.sql.column()' " + f"Attribute '{key}' on class {cls} appears to " + "be a non-schema 'sqlalchemy.sql.column()' " "object; this won't be part of the declarative mapping" - % (key, cls) ) def _produce_column_copies( self, attributes_for_class, attribute_is_overridden ): cls = self.cls - dict_ = self.dict_ + dict_ = self.clsdict_view + collected_attributes = self.collected_attributes column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj, is_dataclass in attributes_for_class(): - if isinstance(obj, Column): + for name, obj, annotation, is_dataclass in attributes_for_class(): + if isinstance(obj, (Column, MappedColumn)): if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the # superclass), skip continue - elif obj.foreign_keys: - raise exc.InvalidRequestError( - "Columns with foreign keys to other columns " - "must be declared as @declared_attr callables " - "on declarative mixin classes. For dataclass " - "field() objects, use a lambda:." - ) elif name not in dict_ and not ( "__table__" in dict_ and (obj.name or name) in dict_["__table__"].c ): + if obj.foreign_keys: + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise exc.InvalidRequestError( + "Columns with foreign keys to " + "non-table-bound " + "columns must be declared as " + "@declared_attr callables " + "on declarative mixin classes. " + "For dataclass " + "field() objects, use a lambda:." + ) + column_copies[obj] = copy_ = obj._copy() - copy_._creation_order = obj._creation_order + collected_attributes[name] = copy_ + setattr(cls, name, copy_) - dict_[name] = copy_ def _extract_mappable_attributes(self): cls = self.cls - dict_ = self.dict_ + collected_attributes = self.collected_attributes our_stuff = self.properties @@ -665,13 +772,17 @@ class _ClassScanMapperConfig(_MapperConfig): cls, "_sa_decl_prepare_nocascade", strict=True ) - for k in list(dict_): + for k in list(collected_attributes): if k in ("__table__", "__tablename__", "__mapper_args__"): continue - value = dict_[k] + value = collected_attributes[k] + if _is_declarative_props(value): + # @declared_attr in collected_attributes only occurs here for a + # @declared_attr that's directly on the mapped class; + # for a mixin, these have already been evaluated if value._cascading: util.warn( "Use of @declared_attr.cascading only applies to " @@ -689,13 +800,13 @@ class _ClassScanMapperConfig(_MapperConfig): ): # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) setattr(cls, k, value) if ( isinstance(value, tuple) and len(value) == 1 - and isinstance(value[0], (Column, MapperProperty)) + and isinstance(value[0], (Column, _MappedAttribute)) ): util.warn( "Ignoring declarative-like tuple value of attribute " @@ -703,12 +814,12 @@ class _ClassScanMapperConfig(_MapperConfig): "accidentally placed at the end of the line?" % k ) continue - elif not isinstance(value, (Column, MapperProperty)): + elif not isinstance(value, (Column, MapperProperty, _MapsColumns)): # using @declared_attr for some object that - # isn't Column/MapperProperty; remove from the dict_ + # isn't Column/MapperProperty; remove from the clsdict_view # and place the evaluated value onto the class. if not k.startswith("__"): - dict_.pop(k) + collected_attributes.pop(k) self._warn_for_decl_attributes(cls, k, value) if not late_mapped: setattr(cls, k, value) @@ -722,27 +833,37 @@ class _ClassScanMapperConfig(_MapperConfig): "for the MetaData instance when using a " "declarative base class." ) + elif isinstance(value, _IntrospectsAnnotations): + annotation, is_dataclass = self.collected_annotations.get( + k, (None, None) + ) + value.declarative_scan( + self.registry, cls, k, annotation, is_dataclass + ) our_stuff[k] = value def _extract_declared_columns(self): our_stuff = self.properties - # set up attributes in the order they were created - util.sort_dictionary( - our_stuff, key=lambda key: our_stuff[key]._creation_order - ) - # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) for key, c in list(our_stuff.items()): - if isinstance(c, (ColumnProperty, CompositeProperty)): - for col in c.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - if not isinstance(c, CompositeProperty): - name_to_prop_key[col.name].add(key) - declared_columns.add(col) + if isinstance(c, _MapsColumns): + for col in c.columns_to_assign: + if not isinstance(c, Composite): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + + # remove object from the dictionary that will be passed + # as mapper(properties={...}) if it is not a MapperProperty + # (i.e. this currently means it's a MappedColumn) + mp_to_assign = c.mapper_property_to_assign + if mp_to_assign: + our_stuff[key] = mp_to_assign + else: + del our_stuff[key] + elif isinstance(c, Column): _undefer_column_name(key, c) name_to_prop_key[c.name].add(key) @@ -769,16 +890,12 @@ class _ClassScanMapperConfig(_MapperConfig): cls = self.cls tablename = self.tablename table_args = self.table_args - dict_ = self.dict_ + clsdict_view = self.clsdict_view declared_columns = self.declared_columns manager = attributes.manager_of_class(cls) - declared_columns = self.declared_columns = sorted( - declared_columns, key=lambda c: c._creation_order - ) - - if "__table__" not in dict_ and table is None: + if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): table_cls = util.unbound_method_to_callable(cls.__table_cls__) else: @@ -796,11 +913,11 @@ class _ClassScanMapperConfig(_MapperConfig): else: args = table_args - autoload_with = dict_.get("__autoload_with__") + autoload_with = clsdict_view.get("__autoload_with__") if autoload_with: table_kw["autoload_with"] = autoload_with - autoload = dict_.get("__autoload__") + autoload = clsdict_view.get("__autoload__") if autoload: table_kw["autoload"] = True @@ -1095,18 +1212,21 @@ def _add_attribute(cls, key, value): _undefer_column_name(key, value) cls.__table__.append_column(value, replace_existing=True) cls.__mapper__.add_property(key, value) - elif isinstance(value, ColumnProperty): - for col in value.columns: - if isinstance(col, Column) and col.table is None: - _undefer_column_name(key, col) - cls.__table__.append_column(col, replace_existing=True) - cls.__mapper__.add_property(key, value) + elif isinstance(value, _MapsColumns): + mp = value.mapper_property_to_assign + for col in value.columns_to_assign: + _undefer_column_name(key, col) + cls.__table__.append_column(col, replace_existing=True) + if not mp: + cls.__mapper__.add_property(key, col) + if mp: + cls.__mapper__.add_property(key, mp) elif isinstance(value, MapperProperty): cls.__mapper__.add_property(key, value) elif isinstance(value, QueryableAttribute) and value.key != key: # detect a QueryableAttribute that's already mapped being # assigned elsewhere in userland, turn into a synonym() - value = SynonymProperty(value.key) + value = Synonym(value.key) cls.__mapper__.add_property(key, value) else: type.__setattr__(cls, key, value) @@ -1124,7 +1244,7 @@ def _del_attribute(cls, key): ): value = cls.__dict__[key] if isinstance( - value, (Column, ColumnProperty, MapperProperty, QueryableAttribute) + value, (Column, _MapsColumns, MapperProperty, QueryableAttribute) ): raise NotImplementedError( "Can't un-map individual mapped attributes on a mapped class." diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 5e67b64cd9..4526a8b332 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -10,14 +10,26 @@ that exist as configurational elements, but don't participate as actively in the load/persist ORM loop. """ +import inspect +import itertools +import operator +import typing from typing import Any -from typing import Type +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple from typing import TypeVar +from typing import Union from . import attributes from . import util as orm_util +from .base import Mapped +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator +from .util import _extract_mapped_subtype from .util import _none_set from .. import event from .. import exc as sa_exc @@ -27,6 +39,9 @@ from .. import util from ..sql import expression from ..sql import operators +if typing.TYPE_CHECKING: + from .properties import MappedColumn + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -92,30 +107,48 @@ class DescriptorProperty(MapperProperty[_T]): mapper.class_manager.instrument_attribute(self.key, proxy_attr) -class CompositeProperty(DescriptorProperty[_T]): +class Composite( + _MapsColumns[_T], _IntrospectsAnnotations, DescriptorProperty[_T] +): """Defines a "composite" mapped attribute, representing a collection of columns as one attribute. - :class:`.CompositeProperty` is constructed using the :func:`.composite` + :class:`.Composite` is constructed using the :func:`.composite` function. + .. versionchanged:: 2.0 Renamed :class:`_orm.CompositeProperty` + to :class:`_orm.Composite`. The old name + :class:`_orm.CompositeProperty` remains as an alias. + .. seealso:: :ref:`mapper_composite` """ - def __init__(self, class_: Type[_T], *attrs, **kwargs): - super(CompositeProperty, self).__init__() + composite_class: Union[type, Callable[..., type]] + attrs: Tuple[ + Union[sql.ColumnElement[Any], "MappedColumn", str, Mapped[Any]], ... + ] + + def __init__(self, class_=None, *attrs, **kwargs): + super().__init__() + + if isinstance(class_, (Mapped, str, sql.ColumnElement)): + self.attrs = (class_,) + attrs + # will initialize within declarative_scan + self.composite_class = None # type: ignore + else: + self.composite_class = class_ + self.attrs = attrs - self.attrs = attrs - self.composite_class = class_ self.active_history = kwargs.get("active_history", False) self.deferred = kwargs.get("deferred", False) self.group = kwargs.get("group", None) self.comparator_factory = kwargs.pop( "comparator_factory", self.__class__.Comparator ) + self._generated_composite_accessor = None if "info" in kwargs: self.info = kwargs.pop("info") @@ -123,11 +156,26 @@ class CompositeProperty(DescriptorProperty[_T]): self._create_descriptor() def instrument_class(self, mapper): - super(CompositeProperty, self).instrument_class(mapper) + super().instrument_class(mapper) self._setup_event_handlers() + def _composite_values_from_instance(self, value): + if self._generated_composite_accessor: + return self._generated_composite_accessor(value) + else: + try: + accessor = value.__composite_values__ + except AttributeError as ae: + raise sa_exc.InvalidRequestError( + f"Composite class {self.composite_class.__name__} is not " + f"a dataclass and does not define a __composite_values__()" + " method; can't get state" + ) from ae + else: + return accessor() + def do_init(self): - """Initialization which occurs after the :class:`.CompositeProperty` + """Initialization which occurs after the :class:`.Composite` has been associated with its parent mapper. """ @@ -181,7 +229,8 @@ class CompositeProperty(DescriptorProperty[_T]): setattr(instance, key, None) else: for key, value in zip( - self._attribute_keys, value.__composite_values__() + self._attribute_keys, + self._composite_values_from_instance(value), ): setattr(instance, key, value) @@ -196,18 +245,74 @@ class CompositeProperty(DescriptorProperty[_T]): self.descriptor = property(fget, fset, fdel) + @util.preload_module("sqlalchemy.orm.properties") + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + MappedColumn = util.preloaded.orm_properties.MappedColumn + decl_base = util.preloaded.orm_decl_base + + argument = _extract_mapped_subtype( + annotation, + cls, + key, + MappedColumn, + self.composite_class is None, + is_dataclass_field, + ) + + if argument and self.composite_class is None: + if isinstance(argument, str) or hasattr( + argument, "__forward_arg__" + ): + raise sa_exc.ArgumentError( + f"Can't use forward ref {argument} for composite " + f"class argument" + ) + self.composite_class = argument + insp = inspect.signature(self.composite_class) + for param, attr in itertools.zip_longest( + insp.parameters.values(), self.attrs + ): + if param is None or attr is None: + raise sa_exc.ArgumentError( + f"number of arguments to {self.composite_class.__name__} " + f"class and number of attributes don't match" + ) + if isinstance(attr, MappedColumn): + attr.declarative_scan_for_composite( + registry, cls, key, param.name, param.annotation + ) + elif isinstance(attr, schema.Column): + decl_base._undefer_column_name(param.name, attr) + + if not hasattr(cls, "__composite_values__"): + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + @util.memoized_property def _comparable_elements(self): return [getattr(self.parent.class_, prop.key) for prop in self.props] @util.memoized_property + @util.preload_module("orm.properties") def props(self): props = [] + MappedColumn = util.preloaded.orm_properties.MappedColumn + for attr in self.attrs: if isinstance(attr, str): prop = self.parent.get_property(attr, _configure_mappers=False) elif isinstance(attr, schema.Column): prop = self.parent._columntoproperty[attr] + elif isinstance(attr, MappedColumn): + prop = self.parent._columntoproperty[attr.column] elif isinstance(attr, attributes.InstrumentedAttribute): prop = attr.property else: @@ -220,8 +325,22 @@ class CompositeProperty(DescriptorProperty[_T]): return props @property + @util.preload_module("orm.properties") def columns(self): - return [a for a in self.attrs if isinstance(a, schema.Column)] + MappedColumn = util.preloaded.orm_properties.MappedColumn + return [ + a.column if isinstance(a, MappedColumn) else a + for a in self.attrs + if isinstance(a, (schema.Column, MappedColumn)) + ] + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + return self + + @property + def columns_to_assign(self) -> List[schema.Column]: + return [c for c in self.columns if c.table is None] def _setup_arguments_on_columns(self): """Propagate configuration arguments made on this composite @@ -351,9 +470,7 @@ class CompositeProperty(DescriptorProperty[_T]): class CompositeBundle(orm_util.Bundle): def __init__(self, property_, expr): self.property = property_ - super(CompositeProperty.CompositeBundle, self).__init__( - property_.key, *expr - ) + super().__init__(property_.key, *expr) def create_row_processor(self, query, procs, labels): def proc(row): @@ -365,7 +482,7 @@ class CompositeProperty(DescriptorProperty[_T]): class Comparator(PropComparator[_PT]): """Produce boolean, comparison, and other operators for - :class:`.CompositeProperty` attributes. + :class:`.Composite` attributes. See the example in :ref:`composite_operations` for an overview of usage , as well as the documentation for :class:`.PropComparator`. @@ -402,7 +519,7 @@ class CompositeProperty(DescriptorProperty[_T]): "proxy_key": self.prop.key, } ) - return CompositeProperty.CompositeBundle(self.prop, clauses) + return Composite.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): if isinstance(value, sql.elements.BindParameter): @@ -411,7 +528,7 @@ class CompositeProperty(DescriptorProperty[_T]): if value is None: values = [None for key in self.prop._attribute_keys] elif isinstance(value, self.prop.composite_class): - values = value.__composite_values__() + values = self.prop._composite_values_from_instance(value) else: raise sa_exc.ArgumentError( "Can't UPDATE composite attribute %s to %r" @@ -434,7 +551,7 @@ class CompositeProperty(DescriptorProperty[_T]): if other is None: values = [None] * len(self.prop._comparable_elements) else: - values = other.__composite_values__() + values = self.prop._composite_values_from_instance(other) comparisons = [ a == b for a, b in zip(self.prop._comparable_elements, values) ] @@ -477,7 +594,7 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): return comparator_callable def __init__(self): - super(ConcreteInheritedProperty, self).__init__() + super().__init__() def warn(): raise AttributeError( @@ -502,7 +619,24 @@ class ConcreteInheritedProperty(DescriptorProperty[_T]): self.descriptor = NoninheritedConcreteProp() -class SynonymProperty(DescriptorProperty[_T]): +class Synonym(DescriptorProperty[_T]): + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + :class:`.Synonym` is constructed using the :func:`_orm.synonym` + function. + + .. versionchanged:: 2.0 Renamed :class:`_orm.SynonymProperty` + to :class:`_orm.Synonym`. The old name + :class:`_orm.SynonymProperty` remains as an alias. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + def __init__( self, name, @@ -512,7 +646,7 @@ class SynonymProperty(DescriptorProperty[_T]): doc=None, info=None, ): - super(SynonymProperty, self).__init__() + super().__init__() self.name = name self.map_column = map_column diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index ade47480d5..3d9c61c205 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -28,7 +28,7 @@ from ..engine import result @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="dynamic") +@relationships.Relationship.strategy_for(lazy="dynamic") class DynaLoader(strategies.AbstractRelationshipLoader): def init_class_attribute(self, mapper): self.is_class_level = True diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b9a5aaf518..1f9ec78f76 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -20,7 +20,12 @@ import collections import typing from typing import Any from typing import cast +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type from typing import TypeVar +from typing import Union from . import exc as orm_exc from . import path_registry @@ -41,8 +46,15 @@ from .. import util from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql._typing import _ColumnsClauseElement from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from .decl_api import RegistryType _T = TypeVar("_T", bound=Any) @@ -85,6 +97,54 @@ class ORMFromClauseRole(roles.StrictFromClauseRole): _role_name = "ORM mapped entity, aliased entity, or FROM expression" +class ORMColumnDescription(TypedDict): + name: str + type: Union[Type, TypeEngine] + aliased: bool + expr: _ColumnsClauseElement + entity: Optional[_ColumnsClauseElement] + + +class _IntrospectsAnnotations: + __slots__ = () + + def declarative_scan( + self, + registry: "RegistryType", + cls: type, + key: str, + annotation: Optional[type], + is_dataclass_field: Optional[bool], + ) -> None: + """Perform class-specific initializaton at early declarative scanning + time. + + .. versionadded:: 2.0 + + """ + + +class _MapsColumns(_MappedAttribute[_T]): + """interface for declarative-capable construct that delivers one or more + Column objects to the declarative process to be part of a Table. + """ + + __slots__ = () + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + """return a MapperProperty to be assigned to the declarative mapping""" + raise NotImplementedError() + + @property + def columns_to_assign(self) -> List[Column]: + """A list of Column objects that should be declaratively added to the + new Table object. + + """ + raise NotImplementedError() + + @inspection._self_inspects class MapperProperty( HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots @@ -96,7 +156,7 @@ class MapperProperty( an instance of :class:`.ColumnProperty`, and a reference to another class produced by :func:`_orm.relationship`, represented in the mapping as an instance of - :class:`.RelationshipProperty`. + :class:`.Relationship`. """ @@ -118,7 +178,7 @@ class MapperProperty( This collection is checked before the 'cascade_iterator' method is called. - The collection typically only applies to a RelationshipProperty. + The collection typically only applies to a Relationship. """ @@ -132,7 +192,7 @@ class MapperProperty( def _links_to_entity(self): """True if this MapperProperty refers to a mapped entity. - Should only be True for RelationshipProperty, False for all others. + Should only be True for Relationship, False for all others. """ raise NotImplementedError() @@ -189,7 +249,7 @@ class MapperProperty( Note that the 'cascade' collection on this MapperProperty is checked first for the given type before cascade_iterator is called. - This method typically only applies to RelationshipProperty. + This method typically only applies to Relationship. """ @@ -323,7 +383,7 @@ class PropComparator( be redefined at both the Core and ORM level. :class:`.PropComparator` is the base class of operator redefinition for ORM-level operations, including those of :class:`.ColumnProperty`, - :class:`.RelationshipProperty`, and :class:`.CompositeProperty`. + :class:`.Relationship`, and :class:`.Composite`. User-defined subclasses of :class:`.PropComparator` may be created. The built-in Python comparison and math operator methods, such as @@ -339,19 +399,19 @@ class PropComparator( from sqlalchemy.orm.properties import \ ColumnProperty,\ - CompositeProperty,\ - RelationshipProperty + Composite,\ + Relationship class MyColumnComparator(ColumnProperty.Comparator): def __eq__(self, other): return self.__clause_element__() == other - class MyRelationshipComparator(RelationshipProperty.Comparator): + class MyRelationshipComparator(Relationship.Comparator): def any(self, expression): "define the 'any' operation" # ... - class MyCompositeComparator(CompositeProperty.Comparator): + class MyCompositeComparator(Composite.Comparator): def __gt__(self, other): "redefine the 'greater than' operation" @@ -386,9 +446,9 @@ class PropComparator( :class:`.ColumnProperty.Comparator` - :class:`.RelationshipProperty.Comparator` + :class:`.Relationship.Comparator` - :class:`.CompositeProperty.Comparator` + :class:`.Composite.Comparator` :class:`.ColumnOperators` @@ -552,7 +612,7 @@ class PropComparator( given criterion. The usual implementation of ``any()`` is - :meth:`.RelationshipProperty.Comparator.any`. + :meth:`.Relationship.Comparator.any`. :param criterion: an optional ClauseElement formulated against the member class' table or attributes. @@ -570,7 +630,7 @@ class PropComparator( given criterion. The usual implementation of ``has()`` is - :meth:`.RelationshipProperty.Comparator.has`. + :meth:`.Relationship.Comparator.has`. :param criterion: an optional ClauseElement formulated against the member class' table or attributes. @@ -606,10 +666,13 @@ class StrategizedProperty(MapperProperty[_T]): "strategy", "_wildcard_token", "_default_path_loader_key", + "strategy_key", ) inherit_cache = True strategy_wildcard_key = None + strategy_key: Tuple[Any, ...] + def _memoized_attr__wildcard_token(self): return ( f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py new file mode 100644 index 0000000000..75abeef4cd --- /dev/null +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -0,0 +1,232 @@ +# orm/collections.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Type +from typing import TypeVar + +from . import base +from .collections import collection +from .. import exc as sa_exc +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class _PlainColumnGetter: + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. + + """ + + __slots__ = ("cols", "composite") + + def __init__(self, cols): + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__(self): + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper): + return self.cols + + def __call__(self, value): + state = base.instance_state(value) + m = base._state_mapper(state) + + key = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + + if self.composite: + return tuple(key) + else: + return key[0] + + +class _SerializableColumnGetterV2(_PlainColumnGetter): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + __slots__ = ("colkeys",) + + def __init__(self, colkeys): + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__(self): + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols): + def _table_key(c): + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key + + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper): + cols = [] + metadata = getattr(mapper.local_table, "metadata", None) + for (ckey, tkey) in self.colkeys: + if tkey is None or metadata is None or tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from mapping_spec, which may be a Column or a sequence + of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + cols = [ + coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) + return _mapped_collection_cls(keyfunc) + + +def attribute_mapped_collection(attr_name: str) -> Type["MappedCollection"]: + """A dictionary-based collection type with attribute-based keying. + + Returns a :class:`.MappedCollection` factory with a keying based on the + 'attr_name' attribute of entities in the collection, where ``attr_name`` + is the string name of the attribute. + + .. warning:: the key value must be assigned to its final value + **before** it is accessed by the attribute mapped collection. + Additionally, changes to the key attribute are **not tracked** + automatically, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See the section :ref:`key_collections_mutations` + for an example. + + """ + getter = operator.attrgetter(attr_name) + return _mapped_collection_cls(getter) + + +def mapped_collection( + keyfunc: Callable[[Any], _KT] +) -> Type["MappedCollection[_KT, Any]"]: + """A dictionary-based collection type with arbitrary keying. + + Returns a :class:`.MappedCollection` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, i.e. from None to a database-assigned integer + after a session flush. + + """ + return _mapped_collection_cls(keyfunc) + + +class MappedCollection(Dict[_KT, _VT]): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection + classes require. ``set`` and ``remove`` are implemented in terms + of a keying function: any callable that takes an object and + returns an object for use as a dictionary key. + + """ + + def __init__(self, keyfunc): + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. + + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + + @classmethod + def _unreduce(cls, keyfunc, values): + mp = MappedCollection(keyfunc) + mp.update(values) + return mp + + def __reduce__(self): + return (MappedCollection._unreduce, (self.keyfunc, dict(self))) + + @collection.appender + @collection.internally_instrumented + def set(self, value, _sa_initiator=None): + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + + @collection.remover + @collection.internally_instrumented + def remove(self, value, _sa_initiator=None): + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % (value, self[key], key) + ) + self.__delitem__(key, _sa_initiator) + + +def _mapped_collection_cls(keyfunc): + class _MKeyfuncMapped(MappedCollection): + def __init__(self): + super().__init__(keyfunc) + + return _MKeyfuncMapped diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index fdf065488a..cd0d1e8203 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -580,7 +580,16 @@ class Mapper( self.version_id_prop = version_id_col self.version_id_col = None else: - self.version_id_col = version_id_col + self.version_id_col = ( + coercions.expect( + roles.ColumnArgumentOrKeyRole, + version_id_col, + argname="version_id_col", + ) + if version_id_col is not None + else None + ) + if version_id_generator is False: self.version_id_generator = False elif version_id_generator is None: @@ -2473,7 +2482,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") def synonyms(self): - """Return a namespace of all :class:`.SynonymProperty` + """Return a namespace of all :class:`.Synonym` properties maintained by this :class:`_orm.Mapper`. .. seealso:: @@ -2485,7 +2494,7 @@ class Mapper( """ descriptor_props = util.preloaded.orm_descriptor_props - return self._filter_properties(descriptor_props.SynonymProperty) + return self._filter_properties(descriptor_props.Synonym) @property def entity_namespace(self): @@ -2508,7 +2517,7 @@ class Mapper( @util.preload_module("sqlalchemy.orm.relationships") @HasMemoized.memoized_attribute def relationships(self): - """A namespace of all :class:`.RelationshipProperty` properties + """A namespace of all :class:`.Relationship` properties maintained by this :class:`_orm.Mapper`. .. warning:: @@ -2531,13 +2540,13 @@ class Mapper( """ return self._filter_properties( - util.preloaded.orm_relationships.RelationshipProperty + util.preloaded.orm_relationships.Relationship ) @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") def composites(self): - """Return a namespace of all :class:`.CompositeProperty` + """Return a namespace of all :class:`.Composite` properties maintained by this :class:`_orm.Mapper`. .. seealso:: @@ -2548,7 +2557,7 @@ class Mapper( """ return self._filter_properties( - util.preloaded.orm_descriptor_props.CompositeProperty + util.preloaded.orm_descriptor_props.Composite ) def _filter_properties(self, type_): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b035dbef2f..f28c45fab8 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -13,37 +13,60 @@ mapped attributes. """ from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set from typing import TypeVar from . import attributes from . import strategy_options -from .descriptor_props import CompositeProperty +from .base import SQLCoreOperations +from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty -from .descriptor_props import SynonymProperty +from .descriptor_props import Synonym +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty -from .relationships import RelationshipProperty +from .relationships import Relationship +from .util import _extract_mapped_subtype from .util import _orm_full_deannotate +from .. import exc as sa_exc +from .. import ForeignKey from .. import log from .. import sql from .. import util from ..sql import coercions +from ..sql import operators from ..sql import roles +from ..sql import sqltypes +from ..sql.schema import Column +from ..util.typing import de_optionalize_union_types +from ..util.typing import de_stringify_annotation +from ..util.typing import is_fwd_ref +from ..util.typing import NoneType _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) __all__ = [ "ColumnProperty", - "CompositeProperty", + "Composite", "ConcreteInheritedProperty", - "RelationshipProperty", - "SynonymProperty", + "Relationship", + "Synonym", ] @log.class_logger -class ColumnProperty(StrategizedProperty[_T]): +class ColumnProperty( + _MapsColumns[_T], + StrategizedProperty[_T], + _IntrospectsAnnotations, + log.Identified, +): """Describes an object attribute that corresponds to a table column. Public constructor is the :func:`_orm.column_property` function. @@ -65,7 +88,6 @@ class ColumnProperty(StrategizedProperty[_T]): "active_history", "expire_on_flush", "doc", - "strategy_key", "_creation_order", "_is_polymorphic_discriminator", "_mapped_by_synonym", @@ -84,8 +106,8 @@ class ColumnProperty(StrategizedProperty[_T]): coercions.expect(roles.LabeledColumnExprRole, c) for c in columns ] self.columns = [ - coercions.expect( - roles.LabeledColumnExprRole, _orm_full_deannotate(c) + _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, c) ) for c in columns ] @@ -130,6 +152,27 @@ class ColumnProperty(StrategizedProperty[_T]): if self.raiseload: self.strategy_key += (("raiseload", True),) + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + column = self.columns[0] + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + return self + + @property + def columns_to_assign(self) -> List[Column]: + return [ + c + for c in self.columns + if isinstance(c, Column) and c.table is None + ] + def _memoized_attr__renders_in_subqueries(self): return ("deferred", True) not in self.strategy_key or ( self not in self.parent._readonly_props @@ -197,7 +240,7 @@ class ColumnProperty(StrategizedProperty[_T]): ) def do_init(self): - super(ColumnProperty, self).do_init() + super().do_init() if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( self.columns @@ -364,3 +407,135 @@ class ColumnProperty(StrategizedProperty[_T]): if not self.parent or not self.key: return object.__repr__(self) return str(self.parent.class_.__name__) + "." + self.key + + +class MappedColumn( + SQLCoreOperations[_T], + operators.ColumnOperators[SQLCoreOperations], + _IntrospectsAnnotations, + _MapsColumns[_T], +): + """Maps a single :class:`_schema.Column` on a class. + + :class:`_orm.MappedColumn` is a specialization of the + :class:`_orm.ColumnProperty` class and is oriented towards declarative + configuration. + + To construct :class:`_orm.MappedColumn` objects, use the + :func:`_orm.mapped_column` constructor function. + + .. versionadded:: 2.0 + + + """ + + __slots__ = ( + "column", + "_creation_order", + "foreign_keys", + "_has_nullable", + "deferred", + ) + + deferred: bool + column: Column[_T] + foreign_keys: Optional[Set[ForeignKey]] + + def __init__(self, *arg, **kw): + self.deferred = kw.pop("deferred", False) + self.column = cast("Column[_T]", Column(*arg, **kw)) + self.foreign_keys = self.column.foreign_keys + self._has_nullable = "nullable" in kw + util.set_creation_order(self) + + def _copy(self, **kw): + new = self.__class__.__new__(self.__class__) + new.column = self.column._copy(**kw) + new.deferred = self.deferred + new.foreign_keys = new.column.foreign_keys + new._has_nullable = self._has_nullable + util.set_creation_order(new) + return new + + @property + def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + if self.deferred: + return ColumnProperty(self.column, deferred=True) + else: + return None + + @property + def columns_to_assign(self) -> List[Column]: + return [self.column] + + def __clause_element__(self): + return self.column + + def operate(self, op, *other, **kwargs): + return op(self.__clause_element__(), *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) + + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + column = self.column + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + sqltype = column.type + + argument = _extract_mapped_subtype( + annotation, + cls, + key, + MappedColumn, + sqltype._isnull and not self.column.foreign_keys, + is_dataclass_field, + ) + if argument is None: + return + + self._init_column_for_annotation(cls, registry, argument) + + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan_for_composite( + self, registry, cls, key, param_name, param_annotation + ): + decl_base = util.preloaded.orm_decl_base + decl_base._undefer_column_name(param_name, self.column) + self._init_column_for_annotation(cls, registry, param_annotation) + + def _init_column_for_annotation(self, cls, registry, argument): + sqltype = self.column.type + + nullable = False + + if hasattr(argument, "__origin__"): + nullable = NoneType in argument.__args__ + + if not self._has_nullable: + self.column.nullable = nullable + + if sqltype._isnull and not self.column.foreign_keys: + sqltype = None + our_type = de_optionalize_union_types(argument) + + if is_fwd_ref(our_type): + our_type = de_stringify_annotation(cls, our_type) + + if registry.type_annotation_map: + sqltype = registry.type_annotation_map.get(our_type) + if sqltype is None: + sqltype = sqltypes._type_map_get(our_type) + + if sqltype is None: + raise sa_exc.ArgumentError( + f"Could not locate SQLAlchemy Core " + f"type for Python type: {our_type}" + ) + self.column.type = sqltype diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 15259f130c..61174487ad 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -21,7 +21,12 @@ database to return iterable result sets. import collections.abc as collections_abc import itertools import operator -import typing +from typing import Any +from typing import Generic +from typing import Iterable +from typing import List +from typing import Optional +from typing import TypeVar from . import exc as orm_exc from . import interfaces @@ -35,8 +40,9 @@ from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState from .context import ORMFromStatementCompileState from .context import QueryContext +from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole -from .util import aliased +from .util import AliasedClass from .util import object_mapper from .util import with_parent from .. import exc as sa_exc @@ -45,16 +51,19 @@ from .. import inspection from .. import log from .. import sql from .. import util +from ..engine import Result from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import Select from ..sql import util as sql_util from ..sql import visitors +from ..sql._typing import _FromClauseElement from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative from ..sql.base import Executable +from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg @@ -67,9 +76,12 @@ from ..sql.selectable import SelectBase from ..sql.selectable import SelectStatementGrouping from ..sql.visitors import InternalTraversal -__all__ = ["Query", "QueryContext", "aliased"] -SelfQuery = typing.TypeVar("SelfQuery", bound="Query") +__all__ = ["Query", "QueryContext"] + +_T = TypeVar("_T", bound=Any) + +SelfQuery = TypeVar("SelfQuery", bound="Query") @inspection._self_inspects @@ -80,7 +92,9 @@ class Query( HasPrefixes, HasSuffixes, HasHints, + log.Identified, Executable, + Generic[_T], ): """ORM-level SQL construction object. @@ -1040,7 +1054,7 @@ class Query( for prop in mapper.iterate_properties: if ( - isinstance(prop, relationships.RelationshipProperty) + isinstance(prop, relationships.Relationship) and prop.mapper is entity_zero.mapper ): property = prop # noqa @@ -1064,7 +1078,7 @@ class Query( if alias is not None: # TODO: deprecate - entity = aliased(entity, alias) + entity = AliasedClass(entity, alias) self._raw_columns = list(self._raw_columns) @@ -1992,7 +2006,9 @@ class Query( @_generative @_assertions(_no_clauseelement_condition) - def select_from(self: SelfQuery, *from_obj) -> SelfQuery: + def select_from( + self: SelfQuery, *from_obj: _FromClauseElement + ) -> SelfQuery: r"""Set the FROM clause of this :class:`.Query` explicitly. :meth:`.Query.select_from` is often used in conjunction with @@ -2144,7 +2160,7 @@ class Query( self._distinct = True return self - def all(self): + def all(self) -> List[_T]: """Return the results represented by this :class:`_query.Query` as a list. @@ -2183,7 +2199,7 @@ class Query( self._statement = statement return self - def first(self): + def first(self) -> Optional[_T]: """Return the first result of this ``Query`` or None if the result doesn't contain any row. @@ -2209,7 +2225,7 @@ class Query( else: return self.limit(1)._iter().first() - def one_or_none(self): + def one_or_none(self) -> Optional[_T]: """Return at most one result or raise an exception. Returns ``None`` if the query selects @@ -2235,7 +2251,7 @@ class Query( """ return self._iter().one_or_none() - def one(self): + def one(self) -> _T: """Return exactly one result or raise an exception. Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects @@ -2255,7 +2271,7 @@ class Query( """ return self._iter().one() - def scalar(self): + def scalar(self) -> Any: """Return the first element of the first result or None if no rows present. If multiple rows are returned, raises MultipleResultsFound. @@ -2283,7 +2299,7 @@ class Query( except orm_exc.NoResultFound: return None - def __iter__(self): + def __iter__(self) -> Iterable[_T]: return self._iter().__iter__() def _iter(self): @@ -2309,7 +2325,7 @@ class Query( return result - def __str__(self): + def __str__(self) -> str: statement = self._statement_20() try: @@ -2327,7 +2343,7 @@ class Query( return fn(clause=statement, **kw) @property - def column_descriptions(self): + def column_descriptions(self) -> List[ORMColumnDescription]: """Return metadata about the columns which would be returned by this :class:`_query.Query`. @@ -2368,7 +2384,7 @@ class Query( return _column_descriptions(self, legacy=True) - def instances(self, result_proxy, context=None): + def instances(self, result_proxy: Result, context=None) -> Any: """Return an ORM result given a :class:`_engine.CursorResult` and :class:`.QueryContext`. @@ -2400,6 +2416,7 @@ class Query( if result._attributes.get("filtered", False): result = result.unique() + # TODO: isn't this supposed to be a list? return result @util.became_legacy_20( @@ -2436,7 +2453,7 @@ class Query( return loading.merge_result(self, iterator, load) - def exists(self): + def exists(self) -> Exists: """A convenience method that turns a query into an EXISTS subquery of the form EXISTS (SELECT 1 FROM ... WHERE ...). diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index c5ea07051a..1b8f778c0a 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -13,10 +13,15 @@ SQL annotation and aliasing behavior focused on the `primaryjoin` and `secondaryjoin` aspects of :func:`_orm.relationship`. """ +from __future__ import annotations + import collections +from collections import abc import re +import typing from typing import Any from typing import Callable +from typing import Optional from typing import Type from typing import TypeVar from typing import Union @@ -26,11 +31,13 @@ from . import attributes from . import strategy_options from .base import _is_mapped_class from .base import state_str +from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE from .interfaces import ONETOMANY from .interfaces import PropComparator from .interfaces import StrategizedProperty +from .util import _extract_mapped_subtype from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions @@ -53,10 +60,26 @@ from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +if typing.TYPE_CHECKING: + from .mapper import Mapper + from .util import AliasedClass + from .util import AliasedInsp + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) +_RelationshipArgumentType = Union[ + str, + Type[_T], + Callable[[], Type[_T]], + "Mapper[_T]", + "AliasedClass[_T]", + Callable[[], "Mapper[_T]"], + Callable[[], "AliasedClass[_T]"], +] + + def remote(expr): """Annotate a portion of a primaryjoin expression with a 'remote' annotation. @@ -97,7 +120,9 @@ def foreign(expr): @log.class_logger -class RelationshipProperty(StrategizedProperty[_T]): +class Relationship( + _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified +): """Describes an object property that holds a single item or list of items that correspond to a related database table. @@ -107,6 +132,10 @@ class RelationshipProperty(StrategizedProperty[_T]): :ref:`relationship_config_toplevel` + .. versionchanged:: 2.0 Renamed :class:`_orm.RelationshipProperty` + to :class:`_orm.Relationship`. The old name + :class:`_orm.RelationshipProperty` remains as an alias. + """ strategy_wildcard_key = strategy_options._RELATIONSHIP_TOKEN @@ -126,7 +155,7 @@ class RelationshipProperty(StrategizedProperty[_T]): def __init__( self, - argument: Union[str, Type[_T], Callable[[], Type[_T]]], + argument: Optional[_RelationshipArgumentType[_T]] = None, secondary=None, primaryjoin=None, secondaryjoin=None, @@ -162,7 +191,7 @@ class RelationshipProperty(StrategizedProperty[_T]): sync_backref=None, _legacy_inactive_history_style=False, ): - super(RelationshipProperty, self).__init__() + super(Relationship, self).__init__() self.uselist = uselist self.argument = argument @@ -221,9 +250,7 @@ class RelationshipProperty(StrategizedProperty[_T]): self.local_remote_pairs = _local_remote_pairs self.bake_queries = bake_queries self.load_on_pending = load_on_pending - self.comparator_factory = ( - comparator_factory or RelationshipProperty.Comparator - ) + self.comparator_factory = comparator_factory or Relationship.Comparator self.comparator = self.comparator_factory(self, None) util.set_creation_order(self) @@ -288,7 +315,7 @@ class RelationshipProperty(StrategizedProperty[_T]): class Comparator(PropComparator[_PT]): """Produce boolean, comparison, and other operators for - :class:`.RelationshipProperty` attributes. + :class:`.Relationship` attributes. See the documentation for :class:`.PropComparator` for a brief overview of ORM level operator definition. @@ -318,7 +345,7 @@ class RelationshipProperty(StrategizedProperty[_T]): of_type=None, extra_criteria=(), ): - """Construction of :class:`.RelationshipProperty.Comparator` + """Construction of :class:`.Relationship.Comparator` is internal to the ORM's attribute mechanics. """ @@ -340,7 +367,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @util.memoized_property def entity(self): """The target entity referred to by this - :class:`.RelationshipProperty.Comparator`. + :class:`.Relationship.Comparator`. This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` object. @@ -360,7 +387,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @util.memoized_property def mapper(self): """The target :class:`_orm.Mapper` referred to by this - :class:`.RelationshipProperty.Comparator`. + :class:`.Relationship.Comparator`. This is the "target" or "remote" side of the :func:`_orm.relationship`. @@ -411,7 +438,7 @@ class RelationshipProperty(StrategizedProperty[_T]): """ - return RelationshipProperty.Comparator( + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, @@ -427,7 +454,7 @@ class RelationshipProperty(StrategizedProperty[_T]): .. versionadded:: 1.4 """ - return RelationshipProperty.Comparator( + return Relationship.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, @@ -468,7 +495,7 @@ class RelationshipProperty(StrategizedProperty[_T]): many-to-one comparisons: * Comparisons against collections are not supported. - Use :meth:`~.RelationshipProperty.Comparator.contains`. + Use :meth:`~.Relationship.Comparator.contains`. * Compared to a scalar one-to-many, will produce a clause that compares the target columns in the parent to the given target. @@ -479,7 +506,7 @@ class RelationshipProperty(StrategizedProperty[_T]): queries that go beyond simple AND conjunctions of comparisons, such as those which use OR. Use explicit joins, outerjoins, or - :meth:`~.RelationshipProperty.Comparator.has` for + :meth:`~.Relationship.Comparator.has` for more comprehensive non-many-to-one scalar membership tests. * Comparisons against ``None`` given in a one-to-many @@ -613,12 +640,12 @@ class RelationshipProperty(StrategizedProperty[_T]): EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id AND related.x=2) - Because :meth:`~.RelationshipProperty.Comparator.any` uses + Because :meth:`~.Relationship.Comparator.any` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - :meth:`~.RelationshipProperty.Comparator.any` is particularly + :meth:`~.Relationship.Comparator.any` is particularly useful for testing for empty collections:: session.query(MyClass).filter( @@ -631,10 +658,10 @@ class RelationshipProperty(StrategizedProperty[_T]): NOT (EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id)) - :meth:`~.RelationshipProperty.Comparator.any` is only + :meth:`~.Relationship.Comparator.any` is only valid for collections, i.e. a :func:`_orm.relationship` that has ``uselist=True``. For scalar references, - use :meth:`~.RelationshipProperty.Comparator.has`. + use :meth:`~.Relationship.Comparator.has`. """ if not self.property.uselist: @@ -662,15 +689,15 @@ class RelationshipProperty(StrategizedProperty[_T]): EXISTS (SELECT 1 FROM related WHERE related.id==my_table.related_id AND related.x=2) - Because :meth:`~.RelationshipProperty.Comparator.has` uses + Because :meth:`~.Relationship.Comparator.has` uses a correlated subquery, its performance is not nearly as good when compared against large target tables as that of using a join. - :meth:`~.RelationshipProperty.Comparator.has` is only + :meth:`~.Relationship.Comparator.has` is only valid for scalar references, i.e. a :func:`_orm.relationship` that has ``uselist=False``. For collection references, - use :meth:`~.RelationshipProperty.Comparator.any`. + use :meth:`~.Relationship.Comparator.any`. """ if self.property.uselist: @@ -683,7 +710,7 @@ class RelationshipProperty(StrategizedProperty[_T]): """Return a simple expression that tests a collection for containment of a particular item. - :meth:`~.RelationshipProperty.Comparator.contains` is + :meth:`~.Relationship.Comparator.contains` is only valid for a collection, i.e. a :func:`_orm.relationship` that implements one-to-many or many-to-many with ``uselist=True``. @@ -700,12 +727,12 @@ class RelationshipProperty(StrategizedProperty[_T]): Where ```` is the value of the foreign key attribute on ``other`` which refers to the primary key of its parent object. From this it follows that - :meth:`~.RelationshipProperty.Comparator.contains` is + :meth:`~.Relationship.Comparator.contains` is very useful when used with simple one-to-many operations. For many-to-many operations, the behavior of - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` has more caveats. The association table will be rendered in the statement, producing an "implicit" join, that is, includes multiple tables in the FROM @@ -722,14 +749,14 @@ class RelationshipProperty(StrategizedProperty[_T]): Where ```` would be the primary key of ``other``. From the above, it is clear that - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` will **not** work with many-to-many collections when used in queries that move beyond simple AND conjunctions, such as multiple - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` expressions joined by OR. In such cases subqueries or explicit "outer joins" will need to be used instead. - See :meth:`~.RelationshipProperty.Comparator.any` for + See :meth:`~.Relationship.Comparator.any` for a less-performant alternative using EXISTS, or refer to :meth:`_query.Query.outerjoin` as well as :ref:`ormtutorial_joins` @@ -818,7 +845,7 @@ class RelationshipProperty(StrategizedProperty[_T]): * Comparisons against collections are not supported. Use - :meth:`~.RelationshipProperty.Comparator.contains` + :meth:`~.Relationship.Comparator.contains` in conjunction with :func:`_expression.not_`. * Compared to a scalar one-to-many, will produce a clause that compares the target columns in the parent to @@ -830,7 +857,7 @@ class RelationshipProperty(StrategizedProperty[_T]): queries that go beyond simple AND conjunctions of comparisons, such as those which use OR. Use explicit joins, outerjoins, or - :meth:`~.RelationshipProperty.Comparator.has` in + :meth:`~.Relationship.Comparator.has` in conjunction with :func:`_expression.not_` for more comprehensive non-many-to-one scalar membership tests. @@ -1249,7 +1276,7 @@ class RelationshipProperty(StrategizedProperty[_T]): def _add_reverse_property(self, key): other = self.mapper.get_property(key, _configure_mappers=False) - if not isinstance(other, RelationshipProperty): + if not isinstance(other, Relationship): raise sa_exc.InvalidRequestError( "back_populates on relationship '%s' refers to attribute '%s' " "that is not a relationship. The back_populates parameter " @@ -1269,6 +1296,8 @@ class RelationshipProperty(StrategizedProperty[_T]): self._reverse_property.add(other) other._reverse_property.add(self) + other._setup_entity() + if not other.mapper.common_parent(self.parent): raise sa_exc.ArgumentError( "reverse_property %r on " @@ -1289,48 +1318,18 @@ class RelationshipProperty(StrategizedProperty[_T]): ) @util.memoized_property - @util.preload_module("sqlalchemy.orm.mapper") - def entity(self): + def entity(self) -> Union["Mapper", "AliasedInsp"]: """Return the target mapped entity, which is an inspect() of the class or aliased class that is referred towards. """ - - mapperlib = util.preloaded.orm_mapper - - if isinstance(self.argument, str): - argument = self._clsregistry_resolve_name(self.argument)() - - elif callable(self.argument) and not isinstance( - self.argument, (type, mapperlib.Mapper) - ): - argument = self.argument() - else: - argument = self.argument - - if isinstance(argument, type): - return mapperlib.class_mapper(argument, configure=False) - - try: - entity = inspect(argument) - except sa_exc.NoInspectionAvailable: - pass - else: - if hasattr(entity, "mapper"): - return entity - - raise sa_exc.ArgumentError( - "relationship '%s' expects " - "a class or a mapper argument (received: %s)" - % (self.key, type(argument)) - ) + self.parent._check_configure() + return self.entity @util.memoized_property - def mapper(self): + def mapper(self) -> "Mapper": """Return the targeted :class:`_orm.Mapper` for this - :class:`.RelationshipProperty`. - - This is a lazy-initializing static attribute. + :class:`.Relationship`. """ return self.entity.mapper @@ -1338,13 +1337,14 @@ class RelationshipProperty(StrategizedProperty[_T]): def do_init(self): self._check_conflicts() self._process_dependent_arguments() + self._setup_entity() self._setup_registry_dependencies() self._setup_join_conditions() self._check_cascade_settings(self._cascade) self._post_init() self._generate_backref() self._join_condition._warn_for_conflicting_sync_targets() - super(RelationshipProperty, self).do_init() + super(Relationship, self).do_init() self._lazy_strategy = self._get_strategy((("lazy", "select"),)) def _setup_registry_dependencies(self): @@ -1432,6 +1432,84 @@ class RelationshipProperty(StrategizedProperty[_T]): for x in util.to_column_set(self.remote_side) ) + def declarative_scan( + self, registry, cls, key, annotation, is_dataclass_field + ): + argument = _extract_mapped_subtype( + annotation, + cls, + key, + Relationship, + self.argument is None, + is_dataclass_field, + ) + if argument is None: + return + + if hasattr(argument, "__origin__"): + + collection_class = argument.__origin__ + if issubclass(collection_class, abc.Collection): + if self.collection_class is None: + self.collection_class = collection_class + else: + self.uselist = False + if argument.__args__: + if issubclass(argument.__origin__, typing.Mapping): + type_arg = argument.__args__[1] + else: + type_arg = argument.__args__[0] + if hasattr(type_arg, "__forward_arg__"): + str_argument = type_arg.__forward_arg__ + argument = str_argument + else: + argument = type_arg + else: + raise sa_exc.ArgumentError( + f"Generic alias {argument} requires an argument" + ) + elif hasattr(argument, "__forward_arg__"): + argument = argument.__forward_arg__ + + self.argument = argument + + @util.preload_module("sqlalchemy.orm.mapper") + def _setup_entity(self, __argument=None): + if "entity" in self.__dict__: + return + + mapperlib = util.preloaded.orm_mapper + + if __argument: + argument = __argument + else: + argument = self.argument + + if isinstance(argument, str): + argument = self._clsregistry_resolve_name(argument)() + elif callable(argument) and not isinstance( + argument, (type, mapperlib.Mapper) + ): + argument = argument() + else: + argument = argument + + if isinstance(argument, type): + entity = mapperlib.class_mapper(argument, configure=False) + else: + try: + entity = inspect(argument) + except sa_exc.NoInspectionAvailable: + entity = None + + if not hasattr(entity, "mapper"): + raise sa_exc.ArgumentError( + "relationship '%s' expects " + "a class or a mapper argument (received: %s)" + % (self.key, type(argument)) + ) + + self.entity = entity # type: ignore self.target = self.entity.persist_selectable def _setup_join_conditions(self): @@ -1502,7 +1580,7 @@ class RelationshipProperty(StrategizedProperty[_T]): @property def cascade(self): """Return the current cascade setting for this - :class:`.RelationshipProperty`. + :class:`.Relationship`. """ return self._cascade @@ -1666,7 +1744,7 @@ class RelationshipProperty(StrategizedProperty[_T]): kwargs.setdefault("passive_updates", self.passive_updates) kwargs.setdefault("sync_backref", self.sync_backref) self.back_populates = backref_key - relationship = RelationshipProperty( + relationship = Relationship( parent, self.secondary, pj, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index cf47ee7299..6911ab5058 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -9,6 +9,15 @@ import contextlib import itertools import sys +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union import weakref from . import attributes @@ -20,12 +29,15 @@ from . import persistence from . import query from . import state as statelib from .base import _class_to_mapper +from .base import _IdentityKeyType from .base import _none_set from .base import _state_mapper from .base import instance_str from .base import object_mapper from .base import object_state from .base import state_str +from .query import Query +from .state import InstanceState from .state_changes import _StateChange from .state_changes import _StateChangeState from .state_changes import _StateChangeStates @@ -34,14 +46,26 @@ from .. import engine from .. import exc as sa_exc from .. import sql from .. import util +from ..engine import Connection +from ..engine import Engine from ..engine.util import TransactionalContext from ..inspection import inspect from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import visitors +from ..sql._typing import _ColumnsClauseElement from ..sql.base import CompileState from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .mapper import Mapper + from ..engine import Row + from ..sql._typing import _ExecuteOptions + from ..sql._typing import _ExecuteParams + from ..sql.base import Executable + from ..sql.schema import Table __all__ = [ "Session", @@ -78,23 +102,60 @@ class _SessionClassMethods: "removed in a future release. Please refer to " ":func:`.session.close_all_sessions`.", ) - def close_all(cls): + def close_all(cls) -> None: """Close *all* sessions in memory.""" close_all_sessions() + @classmethod + @overload + def identity_key( + cls, + class_: type, + ident: Tuple[Any, ...], + *, + identity_token: Optional[str], + ) -> _IdentityKeyType: + ... + + @classmethod + @overload + def identity_key(cls, *, instance: Any) -> _IdentityKeyType: + ... + + @classmethod + @overload + def identity_key( + cls, class_: type, *, row: "Row", identity_token: Optional[str] + ) -> _IdentityKeyType: + ... + @classmethod @util.preload_module("sqlalchemy.orm.util") - def identity_key(cls, *args, **kwargs): + def identity_key( + cls, + class_=None, + ident=None, + *, + instance=None, + row=None, + identity_token=None, + ) -> _IdentityKeyType: """Return an identity key. This is an alias of :func:`.util.identity_key`. """ - return util.preloaded.orm_util.identity_key(*args, **kwargs) + return util.preloaded.orm_util.identity_key( + class_, + ident, + instance=instance, + row=row, + identity_token=identity_token, + ) @classmethod - def object_session(cls, instance): + def object_session(cls, instance: Any) -> "Session": """Return the :class:`.Session` to which an object belongs. This is an alias of :func:`.object_session`. @@ -142,15 +203,26 @@ class ORMExecuteState(util.MemoizedSlots): "_update_execution_options", ) + session: "Session" + statement: "Executable" + parameters: "_ExecuteParams" + execution_options: "_ExecuteOptions" + local_execution_options: "_ExecuteOptions" + bind_arguments: Dict[str, Any] + _compile_state_cls: Type[context.ORMCompileState] + _starting_event_idx: Optional[int] + _events_todo: List[Any] + _update_execution_options: Optional["_ExecuteOptions"] + def __init__( self, - session, - statement, - parameters, - execution_options, - bind_arguments, - compile_state_cls, - events_todo, + session: "Session", + statement: "Executable", + parameters: "_ExecuteParams", + execution_options: "_ExecuteOptions", + bind_arguments: Dict[str, Any], + compile_state_cls: Type[context.ORMCompileState], + events_todo: List[Any], ): self.session = session self.statement = statement @@ -834,7 +906,7 @@ class SessionTransaction(_StateChange, TransactionalContext): (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), SessionTransactionState.CLOSED, ) - def commit(self, _to_root=False): + def commit(self, _to_root: bool = False) -> None: if self._state is not SessionTransactionState.PREPARED: with self._expect_state(SessionTransactionState.PREPARED): self._prepare_impl() @@ -981,18 +1053,42 @@ class Session(_SessionClassMethods): _is_asyncio = False + identity_map: identity.IdentityMap + _new: Dict["InstanceState", Any] + _deleted: Dict["InstanceState", Any] + bind: Optional[Union[Engine, Connection]] + __binds: Dict[ + Union[type, "Mapper", "Table"], + Union[engine.Engine, engine.Connection], + ] + _flusing: bool + _warn_on_events: bool + _transaction: Optional[SessionTransaction] + _nested_transaction: Optional[SessionTransaction] + hash_key: int + autoflush: bool + expire_on_commit: bool + enable_baked_queries: bool + twophase: bool + _query_cls: Type[Query] + def __init__( self, - bind=None, - autoflush=True, - future=True, - expire_on_commit=True, - twophase=False, - binds=None, - enable_baked_queries=True, - info=None, - query_cls=None, - autocommit=False, + bind: Optional[Union[engine.Engine, engine.Connection]] = None, + autoflush: bool = True, + future: Literal[True] = True, + expire_on_commit: bool = True, + twophase: bool = False, + binds: Optional[ + Dict[ + Union[type, "Mapper", "Table"], + Union[engine.Engine, engine.Connection], + ] + ] = None, + enable_baked_queries: bool = True, + info: Optional[Dict[Any, Any]] = None, + query_cls: Optional[Type[query.Query]] = None, + autocommit: Literal[False] = False, ): r"""Construct a new Session. @@ -1054,7 +1150,8 @@ class Session(_SessionClassMethods): :class:`.sessionmaker` function, and is not sent directly to the constructor for ``Session``. - :param enable_baked_queries: defaults to ``True``. A flag consumed + :param enable_baked_queries: legacy; defaults to ``True``. + A parameter consumed by the :mod:`sqlalchemy.ext.baked` extension to determine if "baked queries" should be cached, as is the normal operation of this extension. When set to ``False``, caching as used by @@ -1331,7 +1428,7 @@ class Session(_SessionClassMethods): else: self._transaction.rollback(_to_root=True) - def commit(self): + def commit(self) -> None: """Flush pending changes and commit the current transaction. If no transaction is in progress, the method will first @@ -1353,7 +1450,7 @@ class Session(_SessionClassMethods): self._transaction.commit(_to_root=True) - def prepare(self): + def prepare(self) -> None: """Prepare the current transaction in progress for two phase commit. If no transaction is in progress, this method raises an @@ -1370,7 +1467,11 @@ class Session(_SessionClassMethods): self._transaction.prepare() - def connection(self, bind_arguments=None, execution_options=None): + def connection( + self, + bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional["_ExecuteOptions"] = None, + ) -> "Connection": r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1425,12 +1526,12 @@ class Session(_SessionClassMethods): def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - _parent_execute_state=None, - _add_event=None, + statement: "Executable", + params: Optional["_ExecuteParams"] = None, + execution_options: "_ExecuteOptions" = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, ): r"""Execute a SQL expression construct. @@ -1936,7 +2037,9 @@ class Session(_SessionClassMethods): % (", ".join(context),), ) - def query(self, *entities, **kwargs): + def query( + self, *entities: "_ColumnsClauseElement", **kwargs: Any + ) -> "Query": """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2391,7 +2494,7 @@ class Session(_SessionClassMethods): if persistent_to_deleted is not None: persistent_to_deleted(self, state) - def add(self, instance, _warn=True): + def add(self, instance: Any, _warn: bool = True) -> None: """Place an object in the ``Session``. Its state will be persisted to the database on the next flush diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 07e71d4c0b..316aa7ed73 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -34,7 +34,7 @@ from .interfaces import StrategizedProperty from .session import _state_session from .state import InstanceState from .util import _none_set -from .util import aliased +from .util import AliasedClass from .. import event from .. import exc as sa_exc from .. import inspect @@ -564,7 +564,7 @@ class AbstractRelationshipLoader(LoaderStrategy): @log.class_logger -@relationships.RelationshipProperty.strategy_for(do_nothing=True) +@relationships.Relationship.strategy_for(do_nothing=True) class DoNothingLoader(LoaderStrategy): """Relationship loader that makes no change to the object's state. @@ -576,10 +576,10 @@ class DoNothingLoader(LoaderStrategy): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="noload") -@relationships.RelationshipProperty.strategy_for(lazy=None) +@relationships.Relationship.strategy_for(lazy="noload") +@relationships.Relationship.strategy_for(lazy=None) class NoLoader(AbstractRelationshipLoader): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` with "lazy=None". """ @@ -617,13 +617,13 @@ class NoLoader(AbstractRelationshipLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy=True) -@relationships.RelationshipProperty.strategy_for(lazy="select") -@relationships.RelationshipProperty.strategy_for(lazy="raise") -@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql") -@relationships.RelationshipProperty.strategy_for(lazy="baked_select") +@relationships.Relationship.strategy_for(lazy=True) +@relationships.Relationship.strategy_for(lazy="select") +@relationships.Relationship.strategy_for(lazy="raise") +@relationships.Relationship.strategy_for(lazy="raise_on_sql") +@relationships.Relationship.strategy_for(lazy="baked_select") class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` with "lazy=True", that is loads when first accessed. """ @@ -1214,7 +1214,7 @@ class PostLoader(AbstractRelationshipLoader): ) -@relationships.RelationshipProperty.strategy_for(lazy="immediate") +@relationships.Relationship.strategy_for(lazy="immediate") class ImmediateLoader(PostLoader): __slots__ = () @@ -1250,7 +1250,7 @@ class ImmediateLoader(PostLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="subquery") +@relationships.Relationship.strategy_for(lazy="subquery") class SubqueryLoader(PostLoader): __slots__ = ("join_depth",) @@ -1906,10 +1906,10 @@ class SubqueryLoader(PostLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="joined") -@relationships.RelationshipProperty.strategy_for(lazy=False) +@relationships.Relationship.strategy_for(lazy="joined") +@relationships.Relationship.strategy_for(lazy=False) class JoinedLoader(AbstractRelationshipLoader): - """Provide loading behavior for a :class:`.RelationshipProperty` + """Provide loading behavior for a :class:`.Relationship` using joined eager loading. """ @@ -2628,7 +2628,7 @@ class JoinedLoader(AbstractRelationshipLoader): @log.class_logger -@relationships.RelationshipProperty.strategy_for(lazy="selectin") +@relationships.Relationship.strategy_for(lazy="selectin") class SelectInLoader(PostLoader, util.MemoizedSlots): __slots__ = ( "join_depth", @@ -2721,7 +2721,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ) def _init_for_join(self): - self._parent_alias = aliased(self.parent.class_) + self._parent_alias = AliasedClass(self.parent.class_) pa_insp = inspect(self._parent_alias) pk_cols = [ pa_insp._adapt_element(col) for col in self.parent.primary_key diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 0f993b86cf..3f093e543d 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1808,7 +1808,7 @@ class _AttributeStrategyLoad(_LoadElement): assert pwpi if not pwpi.is_aliased_class: pwpi = inspect( - orm_util.with_polymorphic( + orm_util.AliasedInsp._with_polymorphic_factory( pwpi.mapper.base_mapper, pwpi.mapper, aliased=True, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 75f7110078..45c578355a 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -5,13 +5,22 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php - import re import types +import typing +from typing import Any +from typing import Generic +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union import weakref from . import attributes # noqa from .base import _class_to_mapper # noqa +from .base import _IdentityKeyType from .base import _never_set # noqa from .base import _none_set # noqa from .base import attribute_str # noqa @@ -45,8 +54,17 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots +from ..util.typing import de_stringify_annotation +from ..util.typing import is_origin_of + +if typing.TYPE_CHECKING: + from .mapper import Mapper + from ..engine import Row + from ..sql.selectable import Alias +_T = TypeVar("_T", bound=Any) all_cascades = frozenset( ( @@ -276,7 +294,28 @@ def polymorphic_union( return sql.union_all(*result).alias(aliasname) -def identity_key(*args, **kwargs): +@overload +def identity_key( + class_: type, ident: Tuple[Any, ...], *, identity_token: Optional[str] +) -> _IdentityKeyType: + ... + + +@overload +def identity_key(*, instance: Any) -> _IdentityKeyType: + ... + + +@overload +def identity_key( + class_: type, *, row: "Row", identity_token: Optional[str] +) -> _IdentityKeyType: + ... + + +def identity_key( + class_=None, ident=None, *, instance=None, row=None, identity_token=None +) -> _IdentityKeyType: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -340,29 +379,11 @@ def identity_key(*args, **kwargs): .. versionadded:: 1.2 added identity_token """ - if args: - row = None - largs = len(args) - if largs == 1: - class_ = args[0] - try: - row = kwargs.pop("row") - except KeyError: - ident = kwargs.pop("ident") - elif largs in (2, 3): - class_, ident = args - else: - raise sa_exc.ArgumentError( - "expected up to three positional arguments, " "got %s" % largs - ) - - identity_token = kwargs.pop("identity_token", None) - if kwargs: - raise sa_exc.ArgumentError( - "unknown keyword arguments: %s" % ", ".join(kwargs) - ) + if class_ is not None: mapper = class_mapper(class_) if row is None: + if ident is None: + raise sa_exc.ArgumentError("ident or row is required") return mapper.identity_key_from_primary_key( util.to_list(ident), identity_token=identity_token ) @@ -370,14 +391,11 @@ def identity_key(*args, **kwargs): return mapper.identity_key_from_row( row, identity_token=identity_token ) - else: - instance = kwargs.pop("instance") - if kwargs: - raise sa_exc.ArgumentError( - "unknown keyword arguments: %s" % ", ".join(kwargs.keys) - ) + elif instance is not None: mapper = object_mapper(instance) return mapper.identity_key_from_instance(instance) + else: + raise sa_exc.ArgumentError("class or instance is required") class ORMAdapter(sql_util.ColumnAdapter): @@ -420,7 +438,7 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass: +class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -481,7 +499,7 @@ class AliasedClass: def __init__( self, - mapped_class_or_ac, + mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], alias=None, name=None, flat=False, @@ -611,6 +629,7 @@ class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, sql_base.HasCacheKey, + roles.HasFromClauseElement, InspectionAttr, MemoizedSlots, ): @@ -747,6 +766,73 @@ class AliasedInsp( self._target = mapped_class_or_ac # self._target = mapper.class_ # mapped_class_or_ac + @classmethod + def _alias_factory( + cls, + element: Union[ + Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]" + ], + alias=None, + name=None, + flat=False, + adapt_on_names=False, + ) -> Union["AliasedClass[_T]", "Alias"]: + + if isinstance(element, FromClause): + if adapt_on_names: + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) + if name: + return element.alias(name=name, flat=flat) + else: + return coercions.expect( + roles.AnonymizedFromClauseRole, element, flat=flat + ) + else: + return AliasedClass( + element, + alias=alias, + flat=flat, + name=name, + adapt_on_names=adapt_on_names, + ) + + @classmethod + def _with_polymorphic_factory( + cls, + base, + classes, + selectable=False, + flat=False, + polymorphic_on=None, + aliased=False, + innerjoin=False, + _use_mapper_path=False, + ): + + primary_mapper = _class_to_mapper(base) + + if selectable not in (None, False) and flat: + raise sa_exc.ArgumentError( + "the 'flat' and 'selectable' arguments cannot be passed " + "simultaneously to with_polymorphic()" + ) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) + if aliased or flat: + selectable = selectable._anonymous_fromclause(flat=flat) + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + represents_outer_join=not innerjoin, + ) + @property def entity(self): # to eliminate reference cycles, the AliasedClass is held weakly. @@ -1107,215 +1193,6 @@ inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) inspection._inspects(AliasedInsp)(lambda target: target) -def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False): - """Produce an alias of the given element, usually an :class:`.AliasedClass` - instance. - - E.g.:: - - my_alias = aliased(MyClass) - - session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id) - - The :func:`.aliased` function is used to create an ad-hoc mapping of a - mapped class to a new selectable. By default, a selectable is generated - from the normally mapped selectable (typically a :class:`_schema.Table` - ) using the - :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` - can also be - used to link the class to a new :func:`_expression.select` statement. - Also, the :func:`.with_polymorphic` function is a variant of - :func:`.aliased` that is intended to specify a so-called "polymorphic - selectable", that corresponds to the union of several joined-inheritance - subclasses at once. - - For convenience, the :func:`.aliased` function also accepts plain - :class:`_expression.FromClause` constructs, such as a - :class:`_schema.Table` or - :func:`_expression.select` construct. In those cases, the - :meth:`_expression.FromClause.alias` - method is called on the object and the new - :class:`_expression.Alias` object returned. The returned - :class:`_expression.Alias` is not - ORM-mapped in this case. - - .. seealso:: - - :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` - - :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` - - :ref:`ormtutorial_aliases` - in the legacy :ref:`ormtutorial_toplevel` - - :param element: element to be aliased. Is normally a mapped class, - but for convenience can also be a :class:`_expression.FromClause` - element. - - :param alias: Optional selectable unit to map the element to. This is - usually used to link the object to a subquery, and should be an aliased - select construct as one would produce from the - :meth:`_query.Query.subquery` method or - the :meth:`_expression.Select.subquery` or - :meth:`_expression.Select.alias` methods of the :func:`_expression.select` - construct. - - :param name: optional string name to use for the alias, if not specified - by the ``alias`` parameter. The name, among other things, forms the - attribute name that will be accessible via tuples returned by a - :class:`_query.Query` object. Not supported when creating aliases - of :class:`_sql.Join` objects. - - :param flat: Boolean, will be passed through to the - :meth:`_expression.FromClause.alias` call so that aliases of - :class:`_expression.Join` objects will alias the individual tables - inside the join, rather than creating a subquery. This is generally - supported by all modern databases with regards to right-nested joins - and generally produces more efficient queries. - - :param adapt_on_names: if True, more liberal "matching" will be used when - mapping the mapped columns of the ORM entity to those of the - given selectable - a name-based match will be performed if the - given selectable doesn't otherwise have a column that corresponds - to one on the entity. The use case for this is when associating - an entity with some derived selectable such as one that uses - aggregate functions:: - - class UnitPrice(Base): - __tablename__ = 'unit_price' - ... - unit_id = Column(Integer) - price = Column(Numeric) - - aggregated_unit_price = Session.query( - func.sum(UnitPrice.price).label('price') - ).group_by(UnitPrice.unit_id).subquery() - - aggregated_unit_price = aliased(UnitPrice, - alias=aggregated_unit_price, adapt_on_names=True) - - Above, functions on ``aggregated_unit_price`` which refer to - ``.price`` will return the - ``func.sum(UnitPrice.price).label('price')`` column, as it is - matched on the name "price". Ordinarily, the "price" function - wouldn't have any "column correspondence" to the actual - ``UnitPrice.price`` column as it is not a proxy of the original. - - """ - if isinstance(element, expression.FromClause): - if adapt_on_names: - raise sa_exc.ArgumentError( - "adapt_on_names only applies to ORM elements" - ) - if name: - return element.alias(name=name, flat=flat) - else: - return coercions.expect( - roles.AnonymizedFromClauseRole, element, flat=flat - ) - else: - return AliasedClass( - element, - alias=alias, - flat=flat, - name=name, - adapt_on_names=adapt_on_names, - ) - - -def with_polymorphic( - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - innerjoin=False, - _use_mapper_path=False, -): - """Produce an :class:`.AliasedClass` construct which specifies - columns for descendant mappers of the given base. - - Using this method will ensure that each descendant mapper's - tables are included in the FROM clause, and will allow filter() - criterion to be used against those tables. The resulting - instances will also have those columns already loaded so that - no "post fetch" of those columns will be required. - - .. seealso:: - - :ref:`with_polymorphic` - full discussion of - :func:`_orm.with_polymorphic`. - - :param base: Base class to be aliased. - - :param classes: a single class or mapper, or list of - class/mappers, which inherit from the base class. - Alternatively, it may also be the string ``'*'``, in which case - all descending mapped classes will be added to the FROM clause. - - :param aliased: when True, the selectable will be aliased. For a - JOIN, this means the JOIN will be SELECTed from inside of a subquery - unless the :paramref:`_orm.with_polymorphic.flat` flag is set to - True, which is recommended for simpler use cases. - - :param flat: Boolean, will be passed through to the - :meth:`_expression.FromClause.alias` call so that aliases of - :class:`_expression.Join` objects will alias the individual tables - inside the join, rather than creating a subquery. This is generally - supported by all modern databases with regards to right-nested joins - and generally produces more efficient queries. Setting this flag is - recommended as long as the resulting SQL is functional. - - :param selectable: a table or subquery that will - be used in place of the generated FROM clause. This argument is - required if any of the desired classes use concrete table - inheritance, since SQLAlchemy currently cannot generate UNIONs - among tables automatically. If used, the ``selectable`` argument - must represent the full set of tables and columns mapped by every - mapped class. Otherwise, the unaccounted mapped columns will - result in their table being appended directly to the FROM clause - which will usually lead to incorrect results. - - When left at its default value of ``False``, the polymorphic - selectable assigned to the base mapper is used for selecting rows. - However, it may also be passed as ``None``, which will bypass the - configured polymorphic selectable and instead construct an ad-hoc - selectable for the target classes given; for joined table inheritance - this will be a join that includes all target mappers and their - subclasses. - - :param polymorphic_on: a column to be used as the "discriminator" - column for the given selectable. If not given, the polymorphic_on - attribute of the base classes' mapper will be used, if any. This - is useful for mappings that don't have polymorphic loading - behavior by default. - - :param innerjoin: if True, an INNER JOIN will be used. This should - only be specified if querying for one specific subtype only - """ - primary_mapper = _class_to_mapper(base) - - if selectable not in (None, False) and flat: - raise sa_exc.ArgumentError( - "the 'flat' and 'selectable' arguments cannot be passed " - "simultaneously to with_polymorphic()" - ) - - mappers, selectable = primary_mapper._with_polymorphic_args( - classes, selectable, innerjoin=innerjoin - ) - if aliased or flat: - selectable = selectable._anonymous_fromclause(flat=flat) - return AliasedClass( - base, - selectable, - with_polymorphic_mappers=mappers, - with_polymorphic_discriminator=polymorphic_on, - use_mapper_path=_use_mapper_path, - represents_outer_join=not innerjoin, - ) - - @inspection._self_inspects class Bundle( ORMColumnsClauseRole, @@ -1667,62 +1544,6 @@ class _ORMJoin(expression.Join): return _ORMJoin(self, right, onclause, isouter=True, full=full) -def join( - left, right, onclause=None, isouter=False, full=False, join_to_left=None -): - r"""Produce an inner join between left and right clauses. - - :func:`_orm.join` is an extension to the core join interface - provided by :func:`_expression.join()`, where the - left and right selectables may be not only core selectable - objects such as :class:`_schema.Table`, but also mapped classes or - :class:`.AliasedClass` instances. The "on" clause can - be a SQL expression, or an attribute or string name - referencing a configured :func:`_orm.relationship`. - - :func:`_orm.join` is not commonly needed in modern usage, - as its functionality is encapsulated within that of the - :meth:`_query.Query.join` method, which features a - significant amount of automation beyond :func:`_orm.join` - by itself. Explicit usage of :func:`_orm.join` - with :class:`_query.Query` involves usage of the - :meth:`_query.Query.select_from` method, as in:: - - from sqlalchemy.orm import join - session.query(User).\ - select_from(join(User, Address, User.addresses)).\ - filter(Address.email_address=='foo@bar.com') - - In modern SQLAlchemy the above join can be written more - succinctly as:: - - session.query(User).\ - join(User.addresses).\ - filter(Address.email_address=='foo@bar.com') - - See :meth:`_query.Query.join` for information on modern usage - of ORM level joins. - - .. deprecated:: 0.8 - - the ``join_to_left`` parameter is deprecated, and will be removed - in a future release. The parameter has no effect. - - """ - return _ORMJoin(left, right, onclause, isouter, full) - - -def outerjoin(left, right, onclause=None, full=False, join_to_left=None): - """Produce a left outer join between left and right clauses. - - This is the "outer join" version of the :func:`_orm.join` function, - featuring the same behavior except that an OUTER JOIN is generated. - See that function's documentation for other usage details. - - """ - return _ORMJoin(left, right, onclause, True, full) - - def with_parent(instance, prop, from_entity=None): """Create filtering criterion that relates this query's primary entity to the given related instance, using established @@ -1964,3 +1785,56 @@ def _getitem(iterable_query, item): return list(iterable_query)[-1] else: return list(iterable_query[item : item + 1])[0] + + +def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): + annotated = de_stringify_annotation(cls, raw_annotation) + return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") + + +def _extract_mapped_subtype( + raw_annotation: Union[type, str], + cls: type, + key: str, + attr_cls: type, + required: bool, + is_dataclass_field: bool, +) -> Optional[Union[type, str]]: + + if raw_annotation is None: + + if required: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{attr_cls.__name__}" construct are None or not present' + ) + return None + + annotated = de_stringify_annotation(cls, raw_annotation) + + if is_dataclass_field: + return annotated + else: + if ( + not hasattr(annotated, "__origin__") + or not issubclass(annotated.__origin__, attr_cls) + and not issubclass(attr_cls, annotated.__origin__) + ): + our_annotated_str = ( + annotated.__name__ + if not isinstance(annotated, str) + else repr(annotated) + ) + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" should use the ' + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' + ) + + if len(annotated.__args__) != 1: + raise sa_exc.ArgumentError( + "Expected sub-type for Mapped[] annotation" + ) + + return annotated.__args__[0] diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 38059856ee..bc2f93d57e 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -22,36 +22,17 @@ from .base import _AdhocProxiedConnection from .base import _ConnectionFairy from .base import _ConnectionRecord from .base import _finalize_fairy -from .base import Pool -from .base import PoolProxiedConnection -from .base import reset_commit -from .base import reset_none -from .base import reset_rollback -from .impl import AssertionPool -from .impl import AsyncAdaptedQueuePool -from .impl import FallbackAsyncAdaptedQueuePool -from .impl import NullPool -from .impl import QueuePool -from .impl import SingletonThreadPool -from .impl import StaticPool - - -__all__ = [ - "Pool", - "PoolProxiedConnection", - "reset_commit", - "reset_none", - "reset_rollback", - "clear_managers", - "manage", - "AssertionPool", - "NullPool", - "QueuePool", - "AsyncAdaptedQueuePool", - "FallbackAsyncAdaptedQueuePool", - "SingletonThreadPool", - "StaticPool", -] - -# as these are likely to be used in various test suites, debugging -# setups, keep them in the sqlalchemy.pool namespace +from .base import Pool as Pool +from .base import PoolProxiedConnection as PoolProxiedConnection +from .base import reset_commit as reset_commit +from .base import reset_none as reset_none +from .base import reset_rollback as reset_rollback +from .impl import AssertionPool as AssertionPool +from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .impl import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .impl import NullPool as NullPool +from .impl import QueuePool as QueuePool +from .impl import SingletonThreadPool as SingletonThreadPool +from .impl import StaticPool as StaticPool diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c596dee5a6..b2ca1cfefa 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -9,50 +9,54 @@ """ -from .sql.base import SchemaVisitor # noqa -from .sql.ddl import _CreateDropBase # noqa -from .sql.ddl import _DDLCompiles # noqa -from .sql.ddl import _DropView # noqa -from .sql.ddl import AddConstraint # noqa -from .sql.ddl import CreateColumn # noqa -from .sql.ddl import CreateIndex # noqa -from .sql.ddl import CreateSchema # noqa -from .sql.ddl import CreateSequence # noqa -from .sql.ddl import CreateTable # noqa -from .sql.ddl import DDL # noqa -from .sql.ddl import DDLBase # noqa -from .sql.ddl import DDLElement # noqa -from .sql.ddl import DropColumnComment # noqa -from .sql.ddl import DropConstraint # noqa -from .sql.ddl import DropIndex # noqa -from .sql.ddl import DropSchema # noqa -from .sql.ddl import DropSequence # noqa -from .sql.ddl import DropTable # noqa -from .sql.ddl import DropTableComment # noqa -from .sql.ddl import SetColumnComment # noqa -from .sql.ddl import SetTableComment # noqa -from .sql.ddl import sort_tables # noqa -from .sql.ddl import sort_tables_and_constraints # noqa -from .sql.naming import conv # noqa -from .sql.schema import _get_table_key # noqa -from .sql.schema import BLANK_SCHEMA # noqa -from .sql.schema import CheckConstraint # noqa -from .sql.schema import Column # noqa -from .sql.schema import ColumnCollectionConstraint # noqa -from .sql.schema import ColumnCollectionMixin # noqa -from .sql.schema import ColumnDefault # noqa -from .sql.schema import Computed # noqa -from .sql.schema import Constraint # noqa -from .sql.schema import DefaultClause # noqa -from .sql.schema import DefaultGenerator # noqa -from .sql.schema import FetchedValue # noqa -from .sql.schema import ForeignKey # noqa -from .sql.schema import ForeignKeyConstraint # noqa -from .sql.schema import Identity # noqa -from .sql.schema import Index # noqa -from .sql.schema import MetaData # noqa -from .sql.schema import PrimaryKeyConstraint # noqa -from .sql.schema import SchemaItem # noqa -from .sql.schema import Sequence # noqa -from .sql.schema import Table # noqa -from .sql.schema import UniqueConstraint # noqa +from .sql.base import SchemaVisitor as SchemaVisitor +from .sql.ddl import _CreateDropBase as _CreateDropBase +from .sql.ddl import _DDLCompiles as _DDLCompiles +from .sql.ddl import _DropView as _DropView +from .sql.ddl import AddConstraint as AddConstraint +from .sql.ddl import CreateColumn as CreateColumn +from .sql.ddl import CreateIndex as CreateIndex +from .sql.ddl import CreateSchema as CreateSchema +from .sql.ddl import CreateSequence as CreateSequence +from .sql.ddl import CreateTable as CreateTable +from .sql.ddl import DDL as DDL +from .sql.ddl import DDLBase as DDLBase +from .sql.ddl import DDLElement as DDLElement +from .sql.ddl import DropColumnComment as DropColumnComment +from .sql.ddl import DropConstraint as DropConstraint +from .sql.ddl import DropIndex as DropIndex +from .sql.ddl import DropSchema as DropSchema +from .sql.ddl import DropSequence as DropSequence +from .sql.ddl import DropTable as DropTable +from .sql.ddl import DropTableComment as DropTableComment +from .sql.ddl import SetColumnComment as SetColumnComment +from .sql.ddl import SetTableComment as SetTableComment +from .sql.ddl import sort_tables as sort_tables +from .sql.ddl import ( + sort_tables_and_constraints as sort_tables_and_constraints, +) +from .sql.naming import conv as conv +from .sql.schema import _get_table_key as _get_table_key +from .sql.schema import BLANK_SCHEMA as BLANK_SCHEMA +from .sql.schema import CheckConstraint as CheckConstraint +from .sql.schema import Column as Column +from .sql.schema import ( + ColumnCollectionConstraint as ColumnCollectionConstraint, +) +from .sql.schema import ColumnCollectionMixin as ColumnCollectionMixin +from .sql.schema import ColumnDefault as ColumnDefault +from .sql.schema import Computed as Computed +from .sql.schema import Constraint as Constraint +from .sql.schema import DefaultClause as DefaultClause +from .sql.schema import DefaultGenerator as DefaultGenerator +from .sql.schema import FetchedValue as FetchedValue +from .sql.schema import ForeignKey as ForeignKey +from .sql.schema import ForeignKeyConstraint as ForeignKeyConstraint +from .sql.schema import Identity as Identity +from .sql.schema import Index as Index +from .sql.schema import MetaData as MetaData +from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import Sequence as Sequence +from .sql.schema import Table as Table +from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 2f84370aa2..169ddf3dbb 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -75,6 +75,7 @@ from .expression import quoted_name as quoted_name from .expression import Select as Select from .expression import select as select from .expression import Selectable as Selectable +from .expression import SelectLabelStyle as SelectLabelStyle from .expression import StatementLambdaElement as StatementLambdaElement from .expression import Subquery as Subquery from .expression import table as table diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 4b67c12f08..d3cf207da0 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -6,11 +6,11 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php from typing import Any -from typing import Type from typing import Union from . import coercions from . import roles +from ._typing import _ColumnsClauseElement from .elements import ColumnClause from .selectable import Alias from .selectable import CompoundSelect @@ -21,6 +21,8 @@ from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values +from ..util.typing import _LiteralStar +from ..util.typing import Literal def alias(selectable, name=None, flat=False): @@ -279,7 +281,9 @@ def outerjoin(left, right, onclause=None, full=False): return Join(left, right, onclause, isouter=True, full=full) -def select(*entities: Union[roles.ColumnsClauseRole, Type]) -> "Select": +def select( + *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement] +) -> "Select": r"""Construct a new :class:`_expression.Select`. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b5b0efb21a..4d2dd26884 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,9 +1,21 @@ from typing import Any from typing import Mapping from typing import Sequence +from typing import Type from typing import Union +from . import roles +from ..inspection import Inspectable +from ..util import immutabledict + _SingleExecuteParams = Mapping[str, Any] _MultiExecuteParams = Sequence[_SingleExecuteParams] _ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] _ExecuteOptions = Mapping[str, Any] +_ImmutableExecuteOptions = immutabledict[str, Any] +_ColumnsClauseElement = Union[ + roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] +] +_FromClauseElement = Union[ + roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement] +] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f4fe7afab2..5828f9369d 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -21,6 +21,7 @@ from typing import TypeVar from . import roles from . import visitors +from ._typing import _ImmutableExecuteOptions from .cache_key import HasCacheKey # noqa from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa @@ -832,9 +833,8 @@ class Executable(roles.StatementRole, Generative): """ - supports_execution = True - _execution_options = util.immutabledict() - _bind = None + supports_execution: bool = True + _execution_options: _ImmutableExecuteOptions = util.immutabledict() _with_options = () _with_context_options = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9cf4d83974..bf78b4231a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -889,7 +889,7 @@ class SQLCompiler(Compiled): def _apply_numbered_params(self): poscount = itertools.count(1) self.string = re.sub( - r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + r"\[_POSITION\]", lambda m: str(next(poscount)), self.string ) @util.memoized_property diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 18931ce67a..f622023b02 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -10,6 +10,11 @@ to invoke them for a create/drop call. """ import typing +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from . import roles from .base import _generative @@ -21,6 +26,11 @@ from .. import util from ..util import topological +if typing.TYPE_CHECKING: + from .schema import ForeignKeyConstraint + from .schema import Table + + class _DDLCompiles(ClauseElement): _hierarchy_supports_caching = False """disable cache warnings for all _DDLCompiles subclasses. """ @@ -1007,10 +1017,10 @@ class SchemaDropper(DDLBase): def sort_tables( - tables, - skip_fn=None, - extra_dependencies=None, -): + tables: Sequence["Table"], + skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, + extra_dependencies: Optional[Sequence[Tuple["Table", "Table"]]] = None, +) -> List["Table"]: """Sort a collection of :class:`_schema.Table` objects based on dependency. @@ -1051,7 +1061,7 @@ def sort_tables( :param tables: a sequence of :class:`_schema.Table` objects. :param skip_fn: optional callable which will be passed a - :class:`_schema.ForeignKey` object; if it returns True, this + :class:`_schema.ForeignKeyConstraint` object; if it returns True, this constraint will not be considered as a dependency. Note this is **different** from the same parameter in :func:`.sort_tables_and_constraints`, which is diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0ed5bd9865..22195cd7c5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -136,6 +136,7 @@ from .selectable import ScalarSelect as ScalarSelect from .selectable import Select as Select from .selectable import Selectable as Selectable from .selectable import SelectBase as SelectBase +from .selectable import SelectLabelStyle as SelectLabelStyle from .selectable import Subquery as Subquery from .selectable import TableClause as TableClause from .selectable import TableSample as TableSample diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py index 00a2b1d897..15a1566a6f 100644 --- a/lib/sqlalchemy/sql/naming.py +++ b/lib/sqlalchemy/sql/naming.py @@ -14,7 +14,7 @@ import re from . import events # noqa from .elements import _NONE_NAME -from .elements import conv +from .elements import conv as conv from .schema import CheckConstraint from .schema import Column from .schema import Constraint diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 787a1c25ee..b41ef7a5d1 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -4,10 +4,17 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +import typing +from sqlalchemy.util.langhelpers import TypingOnly from .. import util +if typing.TYPE_CHECKING: + from .elements import ClauseElement + from .selectable import FromClause + + class SQLRole: """Define a "role" within a SQL statement structure. @@ -284,3 +291,25 @@ class DDLReferredColumnRole(DDLConstraintColumnRole): _role_name = ( "String column name or Column object for DDL foreign key constraint" ) + + +class HasClauseElement(TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> "ClauseElement": + ... + + +class HasFromClauseElement(HasClauseElement, TypingOnly): + """indicates a class that has a __clause_element__() method""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + def __clause_element__(self) -> "FromClause": + ... diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index a04fad05df..9387ae030c 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -31,9 +31,12 @@ as components in SQL expressions. import collections import typing from typing import Any +from typing import Dict +from typing import List from typing import MutableMapping from typing import Optional from typing import overload +from typing import Sequence as _typing_Sequence from typing import Type from typing import TypeVar from typing import Union @@ -52,6 +55,7 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import quoted_name +from .elements import SQLCoreOperations from .elements import TextClause from .selectable import TableClause from .type_api import to_instance @@ -64,9 +68,12 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from .type_api import TypeEngine + from ..engine import Connection + from ..engine import Engine _T = TypeVar("_T", bound="Any") _ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement] +_TAB = TypeVar("_TAB", bound="Table") RETAIN_SCHEMA = util.symbol("retain_schema") @@ -188,313 +195,6 @@ class Table(DialectKWArgs, SchemaItem, TableClause): :ref:`metadata_describing` - Introduction to database metadata - Constructor arguments are as follows: - - :param name: The name of this table as represented in the database. - - The table name, along with the value of the ``schema`` parameter, - forms a key which uniquely identifies this :class:`_schema.Table` - within - the owning :class:`_schema.MetaData` collection. - Additional calls to :class:`_schema.Table` with the same name, - metadata, - and schema name will return the same :class:`_schema.Table` object. - - Names which contain no upper case characters - will be treated as case insensitive names, and will not be quoted - unless they are a reserved word or contain special characters. - A name with any number of upper case characters is considered - to be case sensitive, and will be sent as quoted. - - To enable unconditional quoting for the table name, specify the flag - ``quote=True`` to the constructor, or use the :class:`.quoted_name` - construct to specify the name. - - :param metadata: a :class:`_schema.MetaData` - object which will contain this - table. The metadata is used as a point of association of this table - with other tables which are referenced via foreign key. It also - may be used to associate this table with a particular - :class:`.Connection` or :class:`.Engine`. - - :param \*args: Additional positional arguments are used primarily - to add the list of :class:`_schema.Column` - objects contained within this - table. Similar to the style of a CREATE TABLE statement, other - :class:`.SchemaItem` constructs may be added here, including - :class:`.PrimaryKeyConstraint`, and - :class:`_schema.ForeignKeyConstraint`. - - :param autoload: Defaults to ``False``, unless - :paramref:`_schema.Table.autoload_with` - is set in which case it defaults to ``True``; - :class:`_schema.Column` objects - for this table should be reflected from the database, possibly - augmenting objects that were explicitly specified. - :class:`_schema.Column` and other objects explicitly set on the - table will replace corresponding reflected objects. - - .. deprecated:: 1.4 - - The autoload parameter is deprecated and will be removed in - version 2.0. Please use the - :paramref:`_schema.Table.autoload_with` parameter, passing an - engine or connection. - - .. seealso:: - - :ref:`metadata_reflection_toplevel` - - :param autoload_replace: Defaults to ``True``; when using - :paramref:`_schema.Table.autoload` - in conjunction with :paramref:`_schema.Table.extend_existing`, - indicates - that :class:`_schema.Column` objects present in the already-existing - :class:`_schema.Table` - object should be replaced with columns of the same - name retrieved from the autoload process. When ``False``, columns - already present under existing names will be omitted from the - reflection process. - - Note that this setting does not impact :class:`_schema.Column` objects - specified programmatically within the call to :class:`_schema.Table` - that - also is autoloading; those :class:`_schema.Column` objects will always - replace existing columns of the same name when - :paramref:`_schema.Table.extend_existing` is ``True``. - - .. seealso:: - - :paramref:`_schema.Table.autoload` - - :paramref:`_schema.Table.extend_existing` - - :param autoload_with: An :class:`_engine.Engine` or - :class:`_engine.Connection` object, - or a :class:`_reflection.Inspector` object as returned by - :func:`_sa.inspect` - against one, with which this :class:`_schema.Table` - object will be reflected. - When set to a non-None value, the autoload process will take place - for this table against the given engine or connection. - - :param extend_existing: When ``True``, indicates that if this - :class:`_schema.Table` is already present in the given - :class:`_schema.MetaData`, - apply further arguments within the constructor to the existing - :class:`_schema.Table`. - - If :paramref:`_schema.Table.extend_existing` or - :paramref:`_schema.Table.keep_existing` are not set, - and the given name - of the new :class:`_schema.Table` refers to a :class:`_schema.Table` - that is - already present in the target :class:`_schema.MetaData` collection, - and - this :class:`_schema.Table` - specifies additional columns or other constructs - or flags that modify the table's state, an - error is raised. The purpose of these two mutually-exclusive flags - is to specify what action should be taken when a - :class:`_schema.Table` - is specified that matches an existing :class:`_schema.Table`, - yet specifies - additional constructs. - - :paramref:`_schema.Table.extend_existing` - will also work in conjunction - with :paramref:`_schema.Table.autoload` to run a new reflection - operation against the database, even if a :class:`_schema.Table` - of the same name is already present in the target - :class:`_schema.MetaData`; newly reflected :class:`_schema.Column` - objects - and other options will be added into the state of the - :class:`_schema.Table`, potentially overwriting existing columns - and options of the same name. - - As is always the case with :paramref:`_schema.Table.autoload`, - :class:`_schema.Column` objects can be specified in the same - :class:`_schema.Table` - constructor, which will take precedence. Below, the existing - table ``mytable`` will be augmented with :class:`_schema.Column` - objects - both reflected from the database, as well as the given - :class:`_schema.Column` - named "y":: - - Table("mytable", metadata, - Column('y', Integer), - extend_existing=True, - autoload_with=engine - ) - - .. seealso:: - - :paramref:`_schema.Table.autoload` - - :paramref:`_schema.Table.autoload_replace` - - :paramref:`_schema.Table.keep_existing` - - - :param implicit_returning: True by default - indicates that - RETURNING can be used by default to fetch newly inserted primary key - values, for backends which support this. Note that - :func:`_sa.create_engine` also provides an ``implicit_returning`` - flag. - - :param include_columns: A list of strings indicating a subset of - columns to be loaded via the ``autoload`` operation; table columns who - aren't present in this list will not be represented on the resulting - ``Table`` object. Defaults to ``None`` which indicates all columns - should be reflected. - - :param resolve_fks: Whether or not to reflect :class:`_schema.Table` - objects - related to this one via :class:`_schema.ForeignKey` objects, when - :paramref:`_schema.Table.autoload` or - :paramref:`_schema.Table.autoload_with` is - specified. Defaults to True. Set to False to disable reflection of - related tables as :class:`_schema.ForeignKey` - objects are encountered; may be - used either to save on SQL calls or to avoid issues with related tables - that can't be accessed. Note that if a related table is already present - in the :class:`_schema.MetaData` collection, or becomes present later, - a - :class:`_schema.ForeignKey` object associated with this - :class:`_schema.Table` will - resolve to that table normally. - - .. versionadded:: 1.3 - - .. seealso:: - - :paramref:`.MetaData.reflect.resolve_fks` - - - :param info: Optional data dictionary which will be populated into the - :attr:`.SchemaItem.info` attribute of this object. - - :param keep_existing: When ``True``, indicates that if this Table - is already present in the given :class:`_schema.MetaData`, ignore - further arguments within the constructor to the existing - :class:`_schema.Table`, and return the :class:`_schema.Table` - object as - originally created. This is to allow a function that wishes - to define a new :class:`_schema.Table` on first call, but on - subsequent calls will return the same :class:`_schema.Table`, - without any of the declarations (particularly constraints) - being applied a second time. - - If :paramref:`_schema.Table.extend_existing` or - :paramref:`_schema.Table.keep_existing` are not set, - and the given name - of the new :class:`_schema.Table` refers to a :class:`_schema.Table` - that is - already present in the target :class:`_schema.MetaData` collection, - and - this :class:`_schema.Table` - specifies additional columns or other constructs - or flags that modify the table's state, an - error is raised. The purpose of these two mutually-exclusive flags - is to specify what action should be taken when a - :class:`_schema.Table` - is specified that matches an existing :class:`_schema.Table`, - yet specifies - additional constructs. - - .. seealso:: - - :paramref:`_schema.Table.extend_existing` - - :param listeners: A list of tuples of the form ``(, )`` - which will be passed to :func:`.event.listen` upon construction. - This alternate hook to :func:`.event.listen` allows the establishment - of a listener function specific to this :class:`_schema.Table` before - the "autoload" process begins. Historically this has been intended - for use with the :meth:`.DDLEvents.column_reflect` event, however - note that this event hook may now be associated with the - :class:`_schema.MetaData` object directly:: - - def listen_for_reflect(table, column_info): - "handle the column reflection event" - # ... - - t = Table( - 'sometable', - autoload_with=engine, - listeners=[ - ('column_reflect', listen_for_reflect) - ]) - - .. seealso:: - - :meth:`_events.DDLEvents.column_reflect` - - :param must_exist: When ``True``, indicates that this Table must already - be present in the given :class:`_schema.MetaData` collection, else - an exception is raised. - - :param prefixes: - A list of strings to insert after CREATE in the CREATE TABLE - statement. They will be separated by spaces. - - :param quote: Force quoting of this table's name on or off, corresponding - to ``True`` or ``False``. When left at its default of ``None``, - the column identifier will be quoted according to whether the name is - case sensitive (identifiers with at least one upper case character are - treated as case sensitive), or if it's a reserved word. This flag - is only needed to force quoting of a reserved word which is not known - by the SQLAlchemy dialect. - - .. note:: setting this flag to ``False`` will not provide - case-insensitive behavior for table reflection; table reflection - will always search for a mixed-case name in a case sensitive - fashion. Case insensitive names are specified in SQLAlchemy only - by stating the name with all lower case characters. - - :param quote_schema: same as 'quote' but applies to the schema identifier. - - :param schema: The schema name for this table, which is required if - the table resides in a schema other than the default selected schema - for the engine's database connection. Defaults to ``None``. - - If the owning :class:`_schema.MetaData` of this :class:`_schema.Table` - specifies its - own :paramref:`_schema.MetaData.schema` parameter, - then that schema name will - be applied to this :class:`_schema.Table` - if the schema parameter here is set - to ``None``. To set a blank schema name on a :class:`_schema.Table` - that - would otherwise use the schema set on the owning - :class:`_schema.MetaData`, - specify the special symbol :attr:`.BLANK_SCHEMA`. - - .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to - allow a :class:`_schema.Table` - to have a blank schema name even when the - parent :class:`_schema.MetaData` specifies - :paramref:`_schema.MetaData.schema`. - - The quoting rules for the schema name are the same as those for the - ``name`` parameter, in that quoting is applied for reserved words or - case-sensitive names; to enable unconditional quoting for the schema - name, specify the flag ``quote_schema=True`` to the constructor, or use - the :class:`.quoted_name` construct to specify the name. - - :param comment: Optional string that will render an SQL comment on table - creation. - - .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment` - parameter - to :class:`_schema.Table`. - - :param \**kw: Additional keyword arguments not mentioned above are - dialect specific, and passed in the form ``_``. - See the documentation regarding an individual dialect at - :ref:`dialect_toplevel` for detail on documented arguments. - """ __visit_name__ = "table" @@ -547,13 +247,21 @@ class Table(DialectKWArgs, SchemaItem, TableClause): else: return (self,) - @util.deprecated_params( - mustexist=( - "1.4", - "Deprecated alias of :paramref:`_schema.Table.must_exist`", - ), - ) - def __new__(cls, *args, **kw): + if not typing.TYPE_CHECKING: + # typing tools seem to be inconsistent in how they handle + # __new__, so suggest this pattern for classes that use + # __new__. apply typing to the __init__ method normally + @util.deprecated_params( + mustexist=( + "1.4", + "Deprecated alias of :paramref:`_schema.Table.must_exist`", + ), + ) + def __new__(cls, *args: Any, **kw: Any) -> Any: + return cls._new(*args, **kw) + + @classmethod + def _new(cls, *args, **kw): if not args and not kw: # python3k pickle seems to call this return object.__new__(cls) @@ -607,14 +315,323 @@ class Table(DialectKWArgs, SchemaItem, TableClause): with util.safe_reraise(): metadata._remove_table(name, schema) - def __init__(self, *args, **kw): - """Constructor for :class:`_schema.Table`. + def __init__( + self, + name: str, + metadata: "MetaData", + *args: SchemaItem, + **kw: Any, + ): + r"""Constructor for :class:`_schema.Table`. - This method is a no-op. See the top-level - documentation for :class:`_schema.Table` - for constructor arguments. - """ + :param name: The name of this table as represented in the database. + + The table name, along with the value of the ``schema`` parameter, + forms a key which uniquely identifies this :class:`_schema.Table` + within + the owning :class:`_schema.MetaData` collection. + Additional calls to :class:`_schema.Table` with the same name, + metadata, + and schema name will return the same :class:`_schema.Table` object. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word or contain special characters. + A name with any number of upper case characters is considered + to be case sensitive, and will be sent as quoted. + + To enable unconditional quoting for the table name, specify the flag + ``quote=True`` to the constructor, or use the :class:`.quoted_name` + construct to specify the name. + + :param metadata: a :class:`_schema.MetaData` + object which will contain this + table. The metadata is used as a point of association of this table + with other tables which are referenced via foreign key. It also + may be used to associate this table with a particular + :class:`.Connection` or :class:`.Engine`. + + :param \*args: Additional positional arguments are used primarily + to add the list of :class:`_schema.Column` + objects contained within this + table. Similar to the style of a CREATE TABLE statement, other + :class:`.SchemaItem` constructs may be added here, including + :class:`.PrimaryKeyConstraint`, and + :class:`_schema.ForeignKeyConstraint`. + + :param autoload: Defaults to ``False``, unless + :paramref:`_schema.Table.autoload_with` + is set in which case it defaults to ``True``; + :class:`_schema.Column` objects + for this table should be reflected from the database, possibly + augmenting objects that were explicitly specified. + :class:`_schema.Column` and other objects explicitly set on the + table will replace corresponding reflected objects. + + .. deprecated:: 1.4 + + The autoload parameter is deprecated and will be removed in + version 2.0. Please use the + :paramref:`_schema.Table.autoload_with` parameter, passing an + engine or connection. + + .. seealso:: + + :ref:`metadata_reflection_toplevel` + + :param autoload_replace: Defaults to ``True``; when using + :paramref:`_schema.Table.autoload` + in conjunction with :paramref:`_schema.Table.extend_existing`, + indicates + that :class:`_schema.Column` objects present in the already-existing + :class:`_schema.Table` + object should be replaced with columns of the same + name retrieved from the autoload process. When ``False``, columns + already present under existing names will be omitted from the + reflection process. + + Note that this setting does not impact :class:`_schema.Column` objects + specified programmatically within the call to :class:`_schema.Table` + that + also is autoloading; those :class:`_schema.Column` objects will always + replace existing columns of the same name when + :paramref:`_schema.Table.extend_existing` is ``True``. + + .. seealso:: + + :paramref:`_schema.Table.autoload` + + :paramref:`_schema.Table.extend_existing` + + :param autoload_with: An :class:`_engine.Engine` or + :class:`_engine.Connection` object, + or a :class:`_reflection.Inspector` object as returned by + :func:`_sa.inspect` + against one, with which this :class:`_schema.Table` + object will be reflected. + When set to a non-None value, the autoload process will take place + for this table against the given engine or connection. + + :param extend_existing: When ``True``, indicates that if this + :class:`_schema.Table` is already present in the given + :class:`_schema.MetaData`, + apply further arguments within the constructor to the existing + :class:`_schema.Table`. + + If :paramref:`_schema.Table.extend_existing` or + :paramref:`_schema.Table.keep_existing` are not set, + and the given name + of the new :class:`_schema.Table` refers to a :class:`_schema.Table` + that is + already present in the target :class:`_schema.MetaData` collection, + and + this :class:`_schema.Table` + specifies additional columns or other constructs + or flags that modify the table's state, an + error is raised. The purpose of these two mutually-exclusive flags + is to specify what action should be taken when a + :class:`_schema.Table` + is specified that matches an existing :class:`_schema.Table`, + yet specifies + additional constructs. + + :paramref:`_schema.Table.extend_existing` + will also work in conjunction + with :paramref:`_schema.Table.autoload` to run a new reflection + operation against the database, even if a :class:`_schema.Table` + of the same name is already present in the target + :class:`_schema.MetaData`; newly reflected :class:`_schema.Column` + objects + and other options will be added into the state of the + :class:`_schema.Table`, potentially overwriting existing columns + and options of the same name. + + As is always the case with :paramref:`_schema.Table.autoload`, + :class:`_schema.Column` objects can be specified in the same + :class:`_schema.Table` + constructor, which will take precedence. Below, the existing + table ``mytable`` will be augmented with :class:`_schema.Column` + objects + both reflected from the database, as well as the given + :class:`_schema.Column` + named "y":: + + Table("mytable", metadata, + Column('y', Integer), + extend_existing=True, + autoload_with=engine + ) + + .. seealso:: + + :paramref:`_schema.Table.autoload` + + :paramref:`_schema.Table.autoload_replace` + + :paramref:`_schema.Table.keep_existing` + + + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. Note that + :func:`_sa.create_engine` also provides an ``implicit_returning`` + flag. + + :param include_columns: A list of strings indicating a subset of + columns to be loaded via the ``autoload`` operation; table columns who + aren't present in this list will not be represented on the resulting + ``Table`` object. Defaults to ``None`` which indicates all columns + should be reflected. + + :param resolve_fks: Whether or not to reflect :class:`_schema.Table` + objects + related to this one via :class:`_schema.ForeignKey` objects, when + :paramref:`_schema.Table.autoload` or + :paramref:`_schema.Table.autoload_with` is + specified. Defaults to True. Set to False to disable reflection of + related tables as :class:`_schema.ForeignKey` + objects are encountered; may be + used either to save on SQL calls or to avoid issues with related tables + that can't be accessed. Note that if a related table is already present + in the :class:`_schema.MetaData` collection, or becomes present later, + a + :class:`_schema.ForeignKey` object associated with this + :class:`_schema.Table` will + resolve to that table normally. + + .. versionadded:: 1.3 + + .. seealso:: + + :paramref:`.MetaData.reflect.resolve_fks` + + + :param info: Optional data dictionary which will be populated into the + :attr:`.SchemaItem.info` attribute of this object. + + :param keep_existing: When ``True``, indicates that if this Table + is already present in the given :class:`_schema.MetaData`, ignore + further arguments within the constructor to the existing + :class:`_schema.Table`, and return the :class:`_schema.Table` + object as + originally created. This is to allow a function that wishes + to define a new :class:`_schema.Table` on first call, but on + subsequent calls will return the same :class:`_schema.Table`, + without any of the declarations (particularly constraints) + being applied a second time. + + If :paramref:`_schema.Table.extend_existing` or + :paramref:`_schema.Table.keep_existing` are not set, + and the given name + of the new :class:`_schema.Table` refers to a :class:`_schema.Table` + that is + already present in the target :class:`_schema.MetaData` collection, + and + this :class:`_schema.Table` + specifies additional columns or other constructs + or flags that modify the table's state, an + error is raised. The purpose of these two mutually-exclusive flags + is to specify what action should be taken when a + :class:`_schema.Table` + is specified that matches an existing :class:`_schema.Table`, + yet specifies + additional constructs. + + .. seealso:: + + :paramref:`_schema.Table.extend_existing` + + :param listeners: A list of tuples of the form ``(, )`` + which will be passed to :func:`.event.listen` upon construction. + This alternate hook to :func:`.event.listen` allows the establishment + of a listener function specific to this :class:`_schema.Table` before + the "autoload" process begins. Historically this has been intended + for use with the :meth:`.DDLEvents.column_reflect` event, however + note that this event hook may now be associated with the + :class:`_schema.MetaData` object directly:: + + def listen_for_reflect(table, column_info): + "handle the column reflection event" + # ... + + t = Table( + 'sometable', + autoload_with=engine, + listeners=[ + ('column_reflect', listen_for_reflect) + ]) + + .. seealso:: + + :meth:`_events.DDLEvents.column_reflect` + + :param must_exist: When ``True``, indicates that this Table must already + be present in the given :class:`_schema.MetaData` collection, else + an exception is raised. + + :param prefixes: + A list of strings to insert after CREATE in the CREATE TABLE + statement. They will be separated by spaces. + + :param quote: Force quoting of this table's name on or off, corresponding + to ``True`` or ``False``. When left at its default of ``None``, + the column identifier will be quoted according to whether the name is + case sensitive (identifiers with at least one upper case character are + treated as case sensitive), or if it's a reserved word. This flag + is only needed to force quoting of a reserved word which is not known + by the SQLAlchemy dialect. + + .. note:: setting this flag to ``False`` will not provide + case-insensitive behavior for table reflection; table reflection + will always search for a mixed-case name in a case sensitive + fashion. Case insensitive names are specified in SQLAlchemy only + by stating the name with all lower case characters. + + :param quote_schema: same as 'quote' but applies to the schema identifier. + + :param schema: The schema name for this table, which is required if + the table resides in a schema other than the default selected schema + for the engine's database connection. Defaults to ``None``. + + If the owning :class:`_schema.MetaData` of this :class:`_schema.Table` + specifies its + own :paramref:`_schema.MetaData.schema` parameter, + then that schema name will + be applied to this :class:`_schema.Table` + if the schema parameter here is set + to ``None``. To set a blank schema name on a :class:`_schema.Table` + that + would otherwise use the schema set on the owning + :class:`_schema.MetaData`, + specify the special symbol :attr:`.BLANK_SCHEMA`. + + .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to + allow a :class:`_schema.Table` + to have a blank schema name even when the + parent :class:`_schema.MetaData` specifies + :paramref:`_schema.MetaData.schema`. + + The quoting rules for the schema name are the same as those for the + ``name`` parameter, in that quoting is applied for reserved words or + case-sensitive names; to enable unconditional quoting for the schema + name, specify the flag ``quote_schema=True`` to the constructor, or use + the :class:`.quoted_name` construct to specify the name. + + :param comment: Optional string that will render an SQL comment on table + creation. + + .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment` + parameter + to :class:`_schema.Table`. + + :param \**kw: Additional keyword arguments not mentioned above are + dialect specific, and passed in the form ``_``. + See the documentation regarding an individual dialect at + :ref:`dialect_toplevel` for detail on documented arguments. + + """ # noqa E501 + # __init__ is overridden to prevent __new__ from # calling the superclass constructor. @@ -1203,7 +1220,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) -> None: ... - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): r""" Construct a new ``Column`` object. @@ -2179,18 +2196,18 @@ class ForeignKey(DialectKWArgs, SchemaItem): def __init__( self, - column, - _constraint=None, - use_alter=False, - name=None, - onupdate=None, - ondelete=None, - deferrable=None, - initially=None, - link_to_name=False, - match=None, - info=None, - **dialect_kw, + column: Union[str, Column, SQLCoreOperations], + _constraint: Optional["ForeignKeyConstraint"] = None, + use_alter: bool = False, + name: Optional[str] = None, + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[bool] = None, + link_to_name: bool = False, + match: Optional[str] = None, + info: Optional[Dict[Any, Any]] = None, + **dialect_kw: Any, ): r""" Construct a column-level FOREIGN KEY. @@ -2337,7 +2354,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) return self._schema_item_copy(fk) - def _get_colspec(self, schema=None, table_name=None): + def _get_colspec(self, schema=None, table_name=None, _is_copy=False): """Return a string based 'column specification' for this :class:`_schema.ForeignKey`. @@ -2357,6 +2374,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): else: return "%s.%s" % (table_name, colname) elif self._table_column is not None: + if self._table_column.table is None: + if _is_copy: + raise exc.InvalidRequestError( + f"Can't copy ForeignKey object which refers to " + f"non-table bound Column {self._table_column!r}" + ) + else: + return self._table_column.key return "%s.%s" % ( self._table_column.table.fullname, self._table_column.key, @@ -3858,6 +3883,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): if target_table is not None and x._table_key() == x.parent.table.key else None, + _is_copy=True, ) for x in self.elements ], @@ -4331,10 +4357,10 @@ class MetaData(SchemaItem): def __init__( self, - schema=None, - quote_schema=None, - naming_convention=None, - info=None, + schema: Optional[str] = None, + quote_schema: Optional[bool] = None, + naming_convention: Optional[Dict[str, str]] = None, + info: Optional[Dict[Any, Any]] = None, ): """Create a new MetaData object. @@ -4465,7 +4491,7 @@ class MetaData(SchemaItem): self._sequences = {} self._fk_memos = collections.defaultdict(list) - tables = None + tables: Dict[str, Table] """A dictionary of :class:`_schema.Table` objects keyed to their name or "table key". @@ -4483,10 +4509,10 @@ class MetaData(SchemaItem): """ - def __repr__(self): + def __repr__(self) -> str: return "MetaData()" - def __contains__(self, table_or_key): + def __contains__(self, table_or_key: Union[str, Table]) -> bool: if not isinstance(table_or_key, str): table_or_key = table_or_key.key return table_or_key in self.tables @@ -4530,20 +4556,20 @@ class MetaData(SchemaItem): self._schemas = state["schemas"] self._fk_memos = state["fk_memos"] - def clear(self): + def clear(self) -> None: """Clear all Table objects from this MetaData.""" dict.clear(self.tables) self._schemas.clear() self._fk_memos.clear() - def remove(self, table): + def remove(self, table: Table) -> None: """Remove the given Table object from this MetaData.""" self._remove_table(table.name, table.schema) @property - def sorted_tables(self): + def sorted_tables(self) -> List[Table]: """Returns a list of :class:`_schema.Table` objects sorted in order of foreign key dependency. @@ -4599,14 +4625,14 @@ class MetaData(SchemaItem): def reflect( self, - bind, - schema=None, - views=False, - only=None, - extend_existing=False, - autoload_replace=True, - resolve_fks=True, - **dialect_kwargs, + bind: Union["Engine", "Connection"], + schema: Optional[str] = None, + views: bool = False, + only: Optional[_typing_Sequence[str]] = None, + extend_existing: bool = False, + autoload_replace: bool = True, + resolve_fks: bool = True, + **dialect_kwargs: Any, ): r"""Load all available table definitions from the database. @@ -4754,7 +4780,12 @@ class MetaData(SchemaItem): except exc.UnreflectableTableError as uerr: util.warn("Skipping table %s: %s" % (name, uerr)) - def create_all(self, bind, tables=None, checkfirst=True): + def create_all( + self, + bind: Union["Engine", "Connection"], + tables: Optional[_typing_Sequence[Table]] = None, + checkfirst: bool = True, + ): """Create all tables stored in this metadata. Conditional by default, will not attempt to recreate tables already @@ -4777,7 +4808,12 @@ class MetaData(SchemaItem): ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables ) - def drop_all(self, bind, tables=None, checkfirst=True): + def drop_all( + self, + bind: Union["Engine", "Connection"], + tables: Optional[_typing_Sequence[Table]] = None, + checkfirst: bool = True, + ): """Drop all tables stored in this metadata. Conditional by default, will not attempt to drop tables not present in diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index e1bbcffec8..b0985f75d8 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -12,14 +12,13 @@ SQL tables and derived rowsets. """ import collections +from enum import Enum import itertools from operator import attrgetter import typing from typing import Any as TODO_Any from typing import Optional from typing import Tuple -from typing import Type -from typing import Union from . import cache_key from . import coercions @@ -28,6 +27,7 @@ from . import roles from . import traversals from . import type_api from . import visitors +from ._typing import _ColumnsClauseElement from .annotation import Annotated from .annotation import SupportsCloneAnnotations from .base import _clone @@ -847,8 +847,11 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): return self.alias(name=name) -LABEL_STYLE_NONE = util.symbol( - "LABEL_STYLE_NONE", +class SelectLabelStyle(Enum): + """Label style constants that may be passed to + :meth:`_sql.Select.set_label_style`.""" + + LABEL_STYLE_NONE = 0 """Label style indicating no automatic labeling should be applied to the columns clause of a SELECT statement. @@ -867,11 +870,9 @@ LABEL_STYLE_NONE = util.symbol( .. versionadded:: 1.4 -""", # noqa E501 -) + """ # noqa E501 -LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol( - "LABEL_STYLE_TABLENAME_PLUS_COL", + LABEL_STYLE_TABLENAME_PLUS_COL = 1 """Label style indicating all columns should be labeled as ``_`` when generating the columns clause of a SELECT statement, to disambiguate same-named columns referenced from different @@ -897,12 +898,9 @@ LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol( .. versionadded:: 1.4 -""", # noqa E501 -) + """ # noqa: E501 - -LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol( - "LABEL_STYLE_DISAMBIGUATE_ONLY", + LABEL_STYLE_DISAMBIGUATE_ONLY = 2 """Label style indicating that columns with a name that conflicts with an existing name should be labeled with a semi-anonymizing label when generating the columns clause of a SELECT statement. @@ -924,17 +922,24 @@ LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol( .. versionadded:: 1.4 -""", # noqa: E501, -) + """ # noqa: E501 + LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY + """The default label style, refers to + :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. -LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY -"""The default label style, refers to -:data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`. + .. versionadded:: 1.4 -.. versionadded:: 1.4 + """ -""" + +( + LABEL_STYLE_NONE, + LABEL_STYLE_TABLENAME_PLUS_COL, + LABEL_STYLE_DISAMBIGUATE_ONLY, +) = list(SelectLabelStyle) + +LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY class Join(roles.DMLTableRole, FromClause): @@ -2870,10 +2875,12 @@ class SelectStatementGrouping(GroupedElement, SelectBase): else: return self - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: return self._label_style - def set_label_style(self, label_style): + def set_label_style( + self, label_style: SelectLabelStyle + ) -> "SelectStatementGrouping": return SelectStatementGrouping( self.element.set_label_style(label_style) ) @@ -3018,7 +3025,7 @@ class GenerativeSelect(SelectBase): ) return self - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: """ Retrieve the current label style. @@ -3027,14 +3034,16 @@ class GenerativeSelect(SelectBase): """ return self._label_style - def set_label_style(self, style): + def set_label_style( + self: SelfGenerativeSelect, style: SelectLabelStyle + ) -> SelfGenerativeSelect: """Return a new selectable with the specified label style. There are three "label styles" available, - :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`, - :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`, and - :data:`_sql.LABEL_STYLE_NONE`. The default style is - :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`. + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY`, + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`, and + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_NONE`. The default style is + :attr:`_sql.SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL`. In modern SQLAlchemy, there is not generally a need to change the labeling style, as per-expression labels are more effectively used by @@ -4131,7 +4140,7 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: Union[roles.ColumnsClauseRole, Type]): + def __init__(self, *entities: _ColumnsClauseElement): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index dd29b2c3ad..6b878dc70b 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -13,6 +13,7 @@ import typing from typing import Any from typing import Callable from typing import Generic +from typing import Optional from typing import Tuple from typing import Type from typing import TypeVar @@ -21,7 +22,7 @@ from typing import Union from .base import SchemaEventTarget from .cache_key import NO_CACHE from .operators import ColumnOperators -from .visitors import Traversible +from .visitors import Visitable from .. import exc from .. import util @@ -52,7 +53,7 @@ _CT = TypeVar("_CT", bound=Any) SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine") -class TypeEngine(Traversible, Generic[_T]): +class TypeEngine(Visitable, Generic[_T]): """The ultimate base class for all SQL datatypes. Common subclasses of :class:`.TypeEngine` include @@ -573,7 +574,7 @@ class TypeEngine(Traversible, Generic[_T]): raise NotImplementedError() def with_variant( - self: SelfTypeEngine, type_: "TypeEngine", dialect_name: str + self: SelfTypeEngine, type_: "TypeEngine", *dialect_names: str ) -> SelfTypeEngine: r"""Produce a copy of this type object that will utilize the given type when applied to the dialect of the given name. @@ -586,7 +587,7 @@ class TypeEngine(Traversible, Generic[_T]): string_type = String() string_type = string_type.with_variant( - mysql.VARCHAR(collation='foo'), 'mysql' + mysql.VARCHAR(collation='foo'), 'mysql', 'mariadb' ) The variant mapping indicates that when this type is @@ -602,16 +603,20 @@ class TypeEngine(Traversible, Generic[_T]): :param type\_: a :class:`.TypeEngine` that will be selected as a variant from the originating type, when a dialect of the given name is in use. - :param dialect_name: base name of the dialect which uses - this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + :param \*dialect_names: one or more base names of the dialect which + uses this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.) + + .. versionchanged:: 2.0 multiple dialect names can be specified + for one variant. """ - if dialect_name in self._variant_mapping: - raise exc.ArgumentError( - "Dialect '%s' is already present in " - "the mapping for this %r" % (dialect_name, self) - ) + for dialect_name in dialect_names: + if dialect_name in self._variant_mapping: + raise exc.ArgumentError( + "Dialect '%s' is already present in " + "the mapping for this %r" % (dialect_name, self) + ) new_type = self.copy() if isinstance(type_, type): type_ = type_() @@ -620,8 +625,9 @@ class TypeEngine(Traversible, Generic[_T]): "can't pass a type that already has variants as a " "dialect-level type to with_variant()" ) + new_type._variant_mapping = self._variant_mapping.union( - {dialect_name: type_} + {dialect_name: type_ for dialect_name in dialect_names} ) return new_type @@ -919,7 +925,7 @@ class ExternalType: """ - cache_ok = None + cache_ok: Optional[bool] = None """Indicate if statements using this :class:`.ExternalType` are "safe to cache". @@ -1357,6 +1363,8 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): _is_type_decorator = True + impl: Union[TypeEngine[Any], Type[TypeEngine[Any]]] + def __init__(self, *args, **kwargs): """Construct a :class:`.TypeDecorator`. diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 268a564215..c1ca670dac 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -6,6 +6,11 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php import collections +import typing +from typing import Any +from typing import Iterable +from typing import Tuple +from typing import Union from .. import util @@ -20,10 +25,15 @@ any_async = False _current = None ident = "main" -_fixture_functions = None # installed by plugin_base +if typing.TYPE_CHECKING: + from .plugin.plugin_base import FixtureFunctions + _fixture_functions: FixtureFunctions +else: + _fixture_functions = None # installed by plugin_base -def combinations(*comb, **kw): + +def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str): r"""Deliver multiple versions of a test based on positional combinations. This is a facade over pytest.mark.parametrize. @@ -89,25 +99,32 @@ def combinations(*comb, **kw): return _fixture_functions.combinations(*comb, **kw) -def combinations_list(arg_iterable, **kw): +def combinations_list( + arg_iterable: Iterable[ + Tuple[ + Any, + ] + ], + **kw, +): "As combination, but takes a single iterable" return combinations(*arg_iterable, **kw) -def fixture(*arg, **kw): +def fixture(*arg: Any, **kw: Any) -> Any: return _fixture_functions.fixture(*arg, **kw) -def get_current_test_name(): +def get_current_test_name() -> str: return _fixture_functions.get_current_test_name() -def mark_base_test_class(): +def mark_base_test_class() -> Any: return _fixture_functions.mark_base_test_class() class _AddToMarker: - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return getattr(_fixture_functions.add_to_marker, attr) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ecc20f1638..7228e5afeb 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -20,6 +20,7 @@ from .util import drop_all_tables_from_metadata from .. import event from .. import util from ..orm import declarative_base +from ..orm import DeclarativeBase from ..orm import registry from ..schema import sort_tables_and_constraints @@ -82,6 +83,21 @@ class TestBase: yield reg reg.dispose() + @config.fixture + def decl_base(self, metadata): + _md = metadata + + class Base(DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + @config.fixture() def future_connection(self, future_engine, connection): # integrate the future_engine and connection fixtures so diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 0b4451b3c8..52e42bb974 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -19,6 +19,7 @@ import logging import os import re import sys +from typing import Any from sqlalchemy.testing import asyncio @@ -738,7 +739,7 @@ class FixtureFunctions(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def mark_base_test_class(self): + def mark_base_test_class(self) -> Any: raise NotImplementedError() @abc.abstractproperty diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 410ab26edc..41e5d6772d 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1325,6 +1325,18 @@ class SuiteRequirements(Requirements): return exclusions.only_if(check) + @property + def no_sqlalchemy2_stubs(self): + def check(config): + try: + __import__("sqlalchemy-stubs.ext.mypy") + except ImportError: + return False + else: + return True + + return exclusions.skip_if(check) + @property def python38(self): return exclusions.only_if( diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 91d15aae08..85bbca20f5 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -6,131 +6,135 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php -from collections import defaultdict -from functools import partial -from functools import update_wrapper +from collections import defaultdict as defaultdict +from functools import partial as partial +from functools import update_wrapper as update_wrapper -from ._collections import coerce_generator_arg -from ._collections import coerce_to_immutabledict -from ._collections import column_dict -from ._collections import column_set -from ._collections import EMPTY_DICT -from ._collections import EMPTY_SET -from ._collections import FacadeDict -from ._collections import flatten_iterator -from ._collections import has_dupes -from ._collections import has_intersection -from ._collections import IdentitySet -from ._collections import ImmutableContainer -from ._collections import immutabledict -from ._collections import ImmutableProperties -from ._collections import LRUCache -from ._collections import ordered_column_set -from ._collections import OrderedDict -from ._collections import OrderedIdentitySet -from ._collections import OrderedProperties -from ._collections import OrderedSet -from ._collections import PopulateDict -from ._collections import Properties -from ._collections import ScopedRegistry -from ._collections import sort_dictionary -from ._collections import ThreadLocalRegistry -from ._collections import to_column_set -from ._collections import to_list -from ._collections import to_set -from ._collections import unique_list -from ._collections import UniqueAppender -from ._collections import update_copy -from ._collections import WeakPopulateDict -from ._collections import WeakSequence -from ._preloaded import preload_module -from ._preloaded import preloaded -from .compat import arm -from .compat import b -from .compat import b64decode -from .compat import b64encode -from .compat import cmp -from .compat import cpython -from .compat import dataclass_fields -from .compat import decode_backslashreplace -from .compat import dottedgetter -from .compat import has_refcount_gc -from .compat import inspect_getfullargspec -from .compat import local_dataclass_fields -from .compat import next -from .compat import osx -from .compat import py38 -from .compat import py39 -from .compat import pypy -from .compat import win32 -from .concurrency import asyncio -from .concurrency import await_fallback -from .concurrency import await_only -from .concurrency import greenlet_spawn -from .concurrency import is_exit_exception -from .deprecations import became_legacy_20 -from .deprecations import deprecated -from .deprecations import deprecated_cls -from .deprecations import deprecated_params -from .deprecations import deprecated_property -from .deprecations import inject_docstring_text -from .deprecations import moved_20 -from .deprecations import warn_deprecated -from .langhelpers import add_parameter_text -from .langhelpers import as_interface -from .langhelpers import asbool -from .langhelpers import asint -from .langhelpers import assert_arg_type -from .langhelpers import attrsetter -from .langhelpers import bool_or_str -from .langhelpers import chop_traceback -from .langhelpers import class_hierarchy -from .langhelpers import classproperty -from .langhelpers import clsname_as_plain_name -from .langhelpers import coerce_kw_type -from .langhelpers import constructor_copy -from .langhelpers import constructor_key -from .langhelpers import counter -from .langhelpers import create_proxy_methods -from .langhelpers import decode_slice -from .langhelpers import decorator -from .langhelpers import dictlike_iteritems -from .langhelpers import duck_type_collection -from .langhelpers import ellipses_string -from .langhelpers import EnsureKWArg -from .langhelpers import format_argspec_init -from .langhelpers import format_argspec_plus -from .langhelpers import generic_repr -from .langhelpers import get_callable_argspec -from .langhelpers import get_cls_kwargs -from .langhelpers import get_func_kwargs -from .langhelpers import getargspec_init -from .langhelpers import has_compiled_ext -from .langhelpers import HasMemoized -from .langhelpers import hybridmethod -from .langhelpers import hybridproperty -from .langhelpers import iterate_attributes -from .langhelpers import map_bits -from .langhelpers import md5_hex -from .langhelpers import memoized_instancemethod -from .langhelpers import memoized_property -from .langhelpers import MemoizedSlots -from .langhelpers import method_is_overridden -from .langhelpers import methods_equivalent -from .langhelpers import monkeypatch_proxied_specials -from .langhelpers import NoneType -from .langhelpers import only_once -from .langhelpers import PluginLoader -from .langhelpers import portable_instancemethod -from .langhelpers import quoted_token_parser -from .langhelpers import safe_reraise -from .langhelpers import set_creation_order -from .langhelpers import string_or_unprintable -from .langhelpers import symbol -from .langhelpers import TypingOnly -from .langhelpers import unbound_method_to_callable -from .langhelpers import walk_subclasses -from .langhelpers import warn -from .langhelpers import warn_exception -from .langhelpers import warn_limited -from .langhelpers import wrap_callable +from ._collections import coerce_generator_arg as coerce_generator_arg +from ._collections import coerce_to_immutabledict as coerce_to_immutabledict +from ._collections import column_dict as column_dict +from ._collections import column_set as column_set +from ._collections import EMPTY_DICT as EMPTY_DICT +from ._collections import EMPTY_SET as EMPTY_SET +from ._collections import FacadeDict as FacadeDict +from ._collections import flatten_iterator as flatten_iterator +from ._collections import has_dupes as has_dupes +from ._collections import has_intersection as has_intersection +from ._collections import IdentitySet as IdentitySet +from ._collections import ImmutableContainer as ImmutableContainer +from ._collections import immutabledict as immutabledict +from ._collections import ImmutableProperties as ImmutableProperties +from ._collections import LRUCache as LRUCache +from ._collections import merge_lists_w_ordering as merge_lists_w_ordering +from ._collections import ordered_column_set as ordered_column_set +from ._collections import OrderedDict as OrderedDict +from ._collections import OrderedIdentitySet as OrderedIdentitySet +from ._collections import OrderedProperties as OrderedProperties +from ._collections import OrderedSet as OrderedSet +from ._collections import PopulateDict as PopulateDict +from ._collections import Properties as Properties +from ._collections import ScopedRegistry as ScopedRegistry +from ._collections import sort_dictionary as sort_dictionary +from ._collections import ThreadLocalRegistry as ThreadLocalRegistry +from ._collections import to_column_set as to_column_set +from ._collections import to_list as to_list +from ._collections import to_set as to_set +from ._collections import unique_list as unique_list +from ._collections import UniqueAppender as UniqueAppender +from ._collections import update_copy as update_copy +from ._collections import WeakPopulateDict as WeakPopulateDict +from ._collections import WeakSequence as WeakSequence +from ._preloaded import preload_module as preload_module +from ._preloaded import preloaded as preloaded +from .compat import arm as arm +from .compat import b as b +from .compat import b64decode as b64decode +from .compat import b64encode as b64encode +from .compat import cmp as cmp +from .compat import cpython as cpython +from .compat import dataclass_fields as dataclass_fields +from .compat import decode_backslashreplace as decode_backslashreplace +from .compat import dottedgetter as dottedgetter +from .compat import has_refcount_gc as has_refcount_gc +from .compat import inspect_getfullargspec as inspect_getfullargspec +from .compat import local_dataclass_fields as local_dataclass_fields +from .compat import osx as osx +from .compat import py38 as py38 +from .compat import py39 as py39 +from .compat import pypy as pypy +from .compat import win32 as win32 +from .concurrency import await_fallback as await_fallback +from .concurrency import await_only as await_only +from .concurrency import greenlet_spawn as greenlet_spawn +from .concurrency import is_exit_exception as is_exit_exception +from .deprecations import became_legacy_20 as became_legacy_20 +from .deprecations import deprecated as deprecated +from .deprecations import deprecated_cls as deprecated_cls +from .deprecations import deprecated_params as deprecated_params +from .deprecations import deprecated_property as deprecated_property +from .deprecations import moved_20 as moved_20 +from .deprecations import warn_deprecated as warn_deprecated +from .langhelpers import add_parameter_text as add_parameter_text +from .langhelpers import as_interface as as_interface +from .langhelpers import asbool as asbool +from .langhelpers import asint as asint +from .langhelpers import assert_arg_type as assert_arg_type +from .langhelpers import attrsetter as attrsetter +from .langhelpers import bool_or_str as bool_or_str +from .langhelpers import chop_traceback as chop_traceback +from .langhelpers import class_hierarchy as class_hierarchy +from .langhelpers import classproperty as classproperty +from .langhelpers import clsname_as_plain_name as clsname_as_plain_name +from .langhelpers import coerce_kw_type as coerce_kw_type +from .langhelpers import constructor_copy as constructor_copy +from .langhelpers import constructor_key as constructor_key +from .langhelpers import counter as counter +from .langhelpers import create_proxy_methods as create_proxy_methods +from .langhelpers import decode_slice as decode_slice +from .langhelpers import decorator as decorator +from .langhelpers import dictlike_iteritems as dictlike_iteritems +from .langhelpers import duck_type_collection as duck_type_collection +from .langhelpers import ellipses_string as ellipses_string +from .langhelpers import EnsureKWArg as EnsureKWArg +from .langhelpers import format_argspec_init as format_argspec_init +from .langhelpers import format_argspec_plus as format_argspec_plus +from .langhelpers import generic_repr as generic_repr +from .langhelpers import get_annotations as get_annotations +from .langhelpers import get_callable_argspec as get_callable_argspec +from .langhelpers import get_cls_kwargs as get_cls_kwargs +from .langhelpers import get_func_kwargs as get_func_kwargs +from .langhelpers import getargspec_init as getargspec_init +from .langhelpers import has_compiled_ext as has_compiled_ext +from .langhelpers import HasMemoized as HasMemoized +from .langhelpers import hybridmethod as hybridmethod +from .langhelpers import hybridproperty as hybridproperty +from .langhelpers import inject_docstring_text as inject_docstring_text +from .langhelpers import iterate_attributes as iterate_attributes +from .langhelpers import map_bits as map_bits +from .langhelpers import md5_hex as md5_hex +from .langhelpers import memoized_instancemethod as memoized_instancemethod +from .langhelpers import memoized_property as memoized_property +from .langhelpers import MemoizedSlots as MemoizedSlots +from .langhelpers import method_is_overridden as method_is_overridden +from .langhelpers import methods_equivalent as methods_equivalent +from .langhelpers import ( + monkeypatch_proxied_specials as monkeypatch_proxied_specials, +) +from .langhelpers import NoneType as NoneType +from .langhelpers import only_once as only_once +from .langhelpers import PluginLoader as PluginLoader +from .langhelpers import portable_instancemethod as portable_instancemethod +from .langhelpers import quoted_token_parser as quoted_token_parser +from .langhelpers import safe_reraise as safe_reraise +from .langhelpers import set_creation_order as set_creation_order +from .langhelpers import string_or_unprintable as string_or_unprintable +from .langhelpers import symbol as symbol +from .langhelpers import TypingOnly as TypingOnly +from .langhelpers import ( + unbound_method_to_callable as unbound_method_to_callable, +) +from .langhelpers import walk_subclasses as walk_subclasses +from .langhelpers import warn as warn +from .langhelpers import warn_exception as warn_exception +from .langhelpers import warn_limited as warn_limited +from .langhelpers import wrap_callable as wrap_callable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 3e4ef1310d..8509868028 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -34,19 +34,27 @@ from ._has_cy import HAS_CYEXTENSION from .typing import Literal if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_collections import immutabledict - from ._py_collections import IdentitySet - from ._py_collections import ImmutableContainer - from ._py_collections import ImmutableDictBase - from ._py_collections import OrderedSet - from ._py_collections import unique_list # noqa + from ._py_collections import immutabledict as immutabledict + from ._py_collections import IdentitySet as IdentitySet + from ._py_collections import ImmutableContainer as ImmutableContainer + from ._py_collections import ImmutableDictBase as ImmutableDictBase + from ._py_collections import OrderedSet as OrderedSet + from ._py_collections import unique_list as unique_list else: - from sqlalchemy.cyextension.immutabledict import ImmutableContainer - from sqlalchemy.cyextension.immutabledict import ImmutableDictBase - from sqlalchemy.cyextension.immutabledict import immutabledict - from sqlalchemy.cyextension.collections import IdentitySet - from sqlalchemy.cyextension.collections import OrderedSet - from sqlalchemy.cyextension.collections import unique_list # noqa + from sqlalchemy.cyextension.immutabledict import ( + ImmutableContainer as ImmutableContainer, + ) + from sqlalchemy.cyextension.immutabledict import ( + ImmutableDictBase as ImmutableDictBase, + ) + from sqlalchemy.cyextension.immutabledict import ( + immutabledict as immutabledict, + ) + from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet + from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet + from sqlalchemy.cyextension.collections import ( # noqa + unique_list as unique_list, + ) _T = TypeVar("_T", bound=Any) @@ -57,6 +65,62 @@ _VT = TypeVar("_VT", bound=Any) EMPTY_SET: FrozenSet[Any] = frozenset() +def merge_lists_w_ordering(a, b): + """merge two lists, maintaining ordering as much as possible. + + this is to reconcile vars(cls) with cls.__annotations__. + + Example:: + + >>> a = ['__tablename__', 'id', 'x', 'created_at'] + >>> b = ['id', 'name', 'data', 'y', 'created_at'] + >>> merge_lists_w_ordering(a, b) + ['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at'] + + This is not necessarily the ordering that things had on the class, + in this case the class is:: + + class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[Optional[str]] + x = Column(Integer) + y: Mapped[int] + created_at: Mapped[datetime.datetime] = mapped_column() + + But things are *mostly* ordered. + + The algorithm could also be done by creating a partial ordering for + all items in both lists and then using topological_sort(), but that + is too much overhead. + + Background on how I came up with this is at: + https://gist.github.com/zzzeek/89de958cf0803d148e74861bd682ebae + + """ + overlap = set(a).intersection(b) + + result = [] + + current, other = iter(a), iter(b) + + while True: + for element in current: + if element in overlap: + overlap.discard(element) + other, current = current, other + break + + result.append(element) + else: + result.extend(other) + break + + return result + + def coerce_to_immutabledict(d): if not d: return EMPTY_DICT diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 0f4befbb1f..62cffa556e 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -39,7 +39,6 @@ arm = "aarch" in platform.machine().lower() has_refcount_gc = bool(cpython) dottedgetter = operator.attrgetter -next = next # noqa class FullArgSpec(typing.NamedTuple): diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 57ef230062..6b94a22948 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -16,15 +16,17 @@ except ImportError as e: pass else: have_greenlet = True - from ._concurrency_py3k import await_only - from ._concurrency_py3k import await_fallback - from ._concurrency_py3k import greenlet_spawn - from ._concurrency_py3k import is_exit_exception - from ._concurrency_py3k import AsyncAdaptedLock - from ._concurrency_py3k import _util_async_run # noqa F401 + from ._concurrency_py3k import await_only as await_only + from ._concurrency_py3k import await_fallback as await_fallback + from ._concurrency_py3k import greenlet_spawn as greenlet_spawn + from ._concurrency_py3k import is_exit_exception as is_exit_exception + from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock from ._concurrency_py3k import ( - _util_async_run_coroutine_function, - ) # noqa F401, E501 + _util_async_run as _util_async_run, + ) # noqa F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa F401, E501 + ) if not have_greenlet: diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 565cbafe26..7c25861665 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -13,6 +13,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Optional +from typing import Tuple from typing import TypeVar from . import compat @@ -209,7 +210,10 @@ def became_legacy_20(api_name, alternative=None, **kw): return deprecated("2.0", message=message, warning=warning_cls, **kw) -def deprecated_params(**specs): +_C = TypeVar("_C", bound=Callable[..., Any]) + + +def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_C], _C]: """Decorates a function to warn on use of certain parameters. e.g. :: diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 9401c249fe..ed879894d5 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -30,6 +30,7 @@ from typing import FrozenSet from typing import Generic from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence @@ -54,6 +55,30 @@ _HP = TypeVar("_HP", bound="hybridproperty") _HM = TypeVar("_HM", bound="hybridmethod") +if compat.py310: + + def get_annotations(obj: Any) -> Mapping[str, Any]: + return inspect.get_annotations(obj) + +else: + + def get_annotations(obj: Any) -> Mapping[str, Any]: + # it's been observed that cls.__annotations__ can be non present. + # it's not clear what causes this, running under tox py37/38 it + # happens, running straight pytest it doesnt + + # https://docs.python.org/3/howto/annotations.html#annotations-howto + if isinstance(obj, type): + ann = obj.__dict__.get("__annotations__", None) + else: + ann = getattr(obj, "__annotations__", None) + + if ann is None: + return _collections.EMPTY_DICT + else: + return cast("Mapping[str, Any]", ann) + + def md5_hex(x: Any) -> str: x = x.encode("utf-8") m = hashlib.md5() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 62a9f6c8a8..56ea4d0e06 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -1,6 +1,10 @@ +import sys import typing from typing import Any from typing import Callable # noqa +from typing import cast +from typing import Dict +from typing import ForwardRef from typing import Generic from typing import overload from typing import Type @@ -13,21 +17,36 @@ from . import compat _T = TypeVar("_T", bound=Any) -if typing.TYPE_CHECKING or not compat.py38: - from typing_extensions import Literal # noqa F401 - from typing_extensions import Protocol # noqa F401 - from typing_extensions import TypedDict # noqa F401 +if compat.py310: + # why they took until py310 to put this in stdlib is beyond me, + # I've been wanting it since py27 + from types import NoneType else: - from typing import Literal # noqa F401 - from typing import Protocol # noqa F401 - from typing import TypedDict # noqa F401 + NoneType = type(None) # type: ignore + +if typing.TYPE_CHECKING or compat.py310: + from typing import Annotated as Annotated +else: + from typing_extensions import Annotated as Annotated # noqa F401 + +if typing.TYPE_CHECKING or compat.py38: + from typing import Literal as Literal + from typing import Protocol as Protocol + from typing import TypedDict as TypedDict +else: + from typing_extensions import Literal as Literal # noqa F401 + from typing_extensions import Protocol as Protocol # noqa F401 + from typing_extensions import TypedDict as TypedDict # noqa F401 + +# work around https://github.com/microsoft/pyright/issues/3025 +_LiteralStar = Literal["*"] if typing.TYPE_CHECKING or not compat.py310: - from typing_extensions import Concatenate # noqa F401 - from typing_extensions import ParamSpec # noqa F401 + from typing_extensions import Concatenate as Concatenate + from typing_extensions import ParamSpec as ParamSpec else: - from typing import Concatenate # noqa F401 - from typing import ParamSpec # noqa F401 + from typing import Concatenate as Concatenate # noqa F401 + from typing import ParamSpec as ParamSpec # noqa F401 class _TypeToInstance(Generic[_T]): @@ -76,3 +95,121 @@ class ReadOnlyInstanceDescriptor(Protocol[_T]): self, instance: object, owner: Any ) -> Union["ReadOnlyInstanceDescriptor[_T]", _T]: ... + + +def de_stringify_annotation( + cls: Type[Any], annotation: Union[str, Type[Any]] +) -> Union[str, Type[Any]]: + """Resolve annotations that may be string based into real objects. + + This is particularly important if a module defines "from __future__ import + annotations", as everything inside of __annotations__ is a string. We want + to at least have generic containers like ``Mapped``, ``Union``, ``List``, + etc. + + """ + + # looked at typing.get_type_hints(), looked at pydantic. We need much + # less here, and we here try to not use any private typing internals + # or construct ForwardRef objects which is documented as something + # that should be avoided. + + if ( + is_fwd_ref(annotation) + and not cast(ForwardRef, annotation).__forward_evaluated__ + ): + annotation = cast(ForwardRef, annotation).__forward_arg__ + + if isinstance(annotation, str): + base_globals: "Dict[str, Any]" = getattr( + sys.modules.get(cls.__module__, None), "__dict__", {} + ) + try: + annotation = eval(annotation, base_globals, None) + except NameError: + pass + return annotation + + +def is_fwd_ref(type_): + return isinstance(type_, ForwardRef) + + +def de_optionalize_union_types(type_): + """Given a type, filter out ``Union`` types that include ``NoneType`` + to not include the ``NoneType``. + + """ + if is_optional(type_): + typ = set(type_.__args__) + + typ.discard(NoneType) + + return make_union_type(*typ) + + else: + return type_ + + +def make_union_type(*types): + """Make a Union type. + + This is needed by :func:`.de_optionalize_union_types` which removes + ``NoneType`` from a ``Union``. + + """ + return cast(Any, Union).__getitem__(types) + + +def expand_unions(type_, include_union=False, discard_none=False): + """Return a type as as a tuple of individual types, expanding for + ``Union`` types.""" + + if is_union(type_): + typ = set(type_.__args__) + + if discard_none: + typ.discard(NoneType) + + if include_union: + return (type_,) + tuple(typ) + else: + return tuple(typ) + else: + return (type_,) + + +def is_optional(type_): + return is_origin_of( + type_, + "Optional", + "Union", + ) + + +def is_union(type_): + return is_origin_of(type_, "Union") + + +def is_origin_of(type_, *names, module=None): + """return True if the given type has an __origin__ with the given name + and optional module.""" + + origin = getattr(type_, "__origin__", None) + if origin is None: + return False + + return _get_type_name(origin) in names and ( + module is None or origin.__module__.startswith(module) + ) + + +def _get_type_name(type_): + if compat.py310: + return type_.__name__ + else: + typ_name = getattr(type_, "__name__", None) + if typ_name is None: + typ_name = getattr(type_, "_name", None) + + return typ_name diff --git a/mypy_plugin.ini b/mypy_plugin.ini new file mode 100644 index 0000000000..34ddc371ce --- /dev/null +++ b/mypy_plugin.ini @@ -0,0 +1,9 @@ +[mypy] +plugins = sqlalchemy.ext.mypy.plugin +show_error_codes = True +mypy_path=./lib/ +strict = True +raise_exceptions=True + +[mypy-sqlalchemy.*] +ignore_errors = True diff --git a/pyproject.toml b/pyproject.toml index 3af6ea089c..be5dd15962 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,10 +65,9 @@ warn_unused_ignores = false strict = true -# https://github.com/python/mypy/issues/8754 -# we are a pep-561 package, so implicit-rexport should be -# enabled -implicit_reexport = true +# some debate at +# https://github.com/python/mypy/issues/8754. +# implicit_reexport = true # individual packages or even modules should be listed here # with strictness-specificity set up. there's no way we are going to get @@ -78,7 +77,6 @@ implicit_reexport = true # strict checking [[tool.mypy.overrides]] module = [ - "sqlalchemy.events", "sqlalchemy.events", "sqlalchemy.exc", "sqlalchemy.inspection", diff --git a/setup.cfg b/setup.cfg index 2eceb0b816..99abcea1c8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -109,6 +109,8 @@ import-order-style = google application-import-names = sqlalchemy,test per-file-ignores = **/__init__.py:F401 + test/ext/mypy/plain_files/*:F821,E501 + test/ext/mypy/plugin_files/*:F821,E501 lib/sqlalchemy/events.py:F401 lib/sqlalchemy/schema.py:F401 lib/sqlalchemy/types.py:F401 diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index 3c89108ee3..79601019e8 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -1,3 +1,4 @@ +import asyncio import threading from sqlalchemy import exc @@ -7,7 +8,6 @@ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_true -from sqlalchemy.util import asyncio from sqlalchemy.util import await_fallback from sqlalchemy.util import await_only from sqlalchemy.util import greenlet_spawn diff --git a/test/base/test_utils.py b/test/base/test_utils.py index dc02c37cb0..67fcc88705 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -31,6 +31,7 @@ from sqlalchemy.util import compat from sqlalchemy.util import get_callable_argspec from sqlalchemy.util import langhelpers from sqlalchemy.util import WeakSequence +from sqlalchemy.util._collections import merge_lists_w_ordering class WeakSequenceTest(fixtures.TestBase): @@ -66,6 +67,49 @@ class WeakSequenceTest(fixtures.TestBase): eq_(len(w._storage), 2) +class MergeListsWOrderingTest(fixtures.TestBase): + @testing.combinations( + ( + ["__tablename__", "id", "x", "created_at"], + ["id", "name", "data", "y", "created_at"], + ["__tablename__", "id", "name", "data", "y", "x", "created_at"], + ), + (["a", "b", "c", "d", "e", "f"], [], ["a", "b", "c", "d", "e", "f"]), + ([], ["a", "b", "c", "d", "e", "f"], ["a", "b", "c", "d", "e", "f"]), + ([], [], []), + (["a", "b", "c"], ["a", "b", "c"], ["a", "b", "c"]), + ( + ["a", "b", "c"], + ["a", "b", "c", "d", "e"], + ["a", "b", "c", "d", "e"], + ), + (["a", "b", "c", "d"], ["c", "d", "e"], ["a", "b", "c", "d", "e"]), + ( + ["a", "c", "e", "g"], + ["b", "d", "f", "g"], + ["a", "c", "e", "b", "d", "f", "g"], # no overlaps until "g" + ), + ( + ["a", "b", "e", "f", "g"], + ["b", "c", "d", "e"], + ["a", "b", "c", "d", "e", "f", "g"], + ), + ( + ["a", "b", "c", "e", "f"], + ["c", "d", "f", "g"], + ["a", "b", "c", "d", "e", "f", "g"], + ), + ( + ["c", "d", "f", "g"], + ["a", "b", "c", "e", "f"], + ["a", "b", "c", "e", "d", "f", "g"], + ), + argnames="a,b,expected", + ) + def test_merge_lists(self, a, b, expected): + eq_(merge_lists_w_ordering(a, b), expected) + + class OrderedDictTest(fixtures.TestBase): def test_odict(self): o = util.OrderedDict() diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 5f33fa46de..613fc80a5a 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -29,7 +29,6 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import TypeDecorator -from sqlalchemy import util from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import JSONB @@ -615,7 +614,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): "id", Integer, primary_key=True, - default=lambda: util.next(counter), + default=lambda: next(counter), ), Column("data", Integer), ) diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 64c32c76bd..a695aadba4 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -547,9 +547,9 @@ class ConcreteInhTest( configure_mappers() self.assert_compile( select(Employee), - "SELECT pjoin.name, pjoin.employee_id, pjoin.type, pjoin._type " - "FROM (SELECT manager.name AS name, manager.employee_id AS " - "employee_id, manager.type AS type, 'manager' AS _type " + "SELECT pjoin.employee_id, pjoin.type, pjoin.name, pjoin._type " + "FROM (SELECT manager.employee_id AS employee_id, " + "manager.type AS type, manager.name AS name, 'manager' AS _type " "FROM manager) AS pjoin", ) @@ -859,13 +859,13 @@ class ConcreteExtensionConfigTest( session = Session() self.assert_compile( session.query(Document), - "SELECT pjoin.doctype AS pjoin_doctype, " - "pjoin.send_method AS pjoin_send_method, " - "pjoin.id AS pjoin_id, pjoin.type AS pjoin_type " - "FROM (SELECT actual_documents.doctype AS doctype, " + "SELECT pjoin.id AS pjoin_id, pjoin.send_method AS " + "pjoin_send_method, pjoin.doctype AS pjoin_doctype, " + "pjoin.type AS pjoin_type FROM " + "(SELECT actual_documents.id AS id, " "actual_documents.send_method AS send_method, " - "actual_documents.id AS id, 'actual' AS type " - "FROM actual_documents) AS pjoin", + "actual_documents.doctype AS doctype, " + "'actual' AS type FROM actual_documents) AS pjoin", ) def test_column_attr_names(self): @@ -886,14 +886,14 @@ class ConcreteExtensionConfigTest( session.query(Document), "SELECT pjoin.documenttype AS pjoin_documenttype, " "pjoin.id AS pjoin_id, pjoin.type AS pjoin_type FROM " - "(SELECT offers.documenttype AS documenttype, offers.id AS id, " + "(SELECT offers.id AS id, offers.documenttype AS documenttype, " "'offer' AS type FROM offers) AS pjoin", ) self.assert_compile( session.query(Document.documentType), "SELECT pjoin.documenttype AS pjoin_documenttype FROM " - "(SELECT offers.documenttype AS documenttype, offers.id AS id, " + "(SELECT offers.id AS id, offers.documenttype AS documenttype, " "'offer' AS type FROM offers) AS pjoin", ) diff --git a/test/ext/mypy/files/inspect.py b/test/ext/mypy/inspection_inspect.py similarity index 100% rename from test/ext/mypy/files/inspect.py rename to test/ext/mypy/inspection_inspect.py diff --git a/test/ext/mypy/plain_files/engine_inspection.py b/test/ext/mypy/plain_files/engine_inspection.py new file mode 100644 index 0000000000..1a1649e4ec --- /dev/null +++ b/test/ext/mypy/plain_files/engine_inspection.py @@ -0,0 +1,24 @@ +import typing + +from sqlalchemy import create_engine +from sqlalchemy import inspect + + +e = create_engine("sqlite://") + +insp = inspect(e) + +cols = insp.get_columns("some_table") + +c1 = cols[0] + +if typing.TYPE_CHECKING: + + # EXPECTED_TYPE: sqlalchemy.engine.base.Engine + reveal_type(e) + + # EXPECTED_TYPE: sqlalchemy.engine.reflection.Inspector.* + reveal_type(insp) + + # EXPECTED_TYPE: .*list.*TypedDict.*ReflectedColumn.* + reveal_type(cols) diff --git a/test/ext/mypy/plain_files/experimental_relationship.py b/test/ext/mypy/plain_files/experimental_relationship.py new file mode 100644 index 0000000000..e97a9598b0 --- /dev/null +++ b/test/ext/mypy/plain_files/experimental_relationship.py @@ -0,0 +1,69 @@ +"""this suite experiments with other kinds of relationship syntaxes. + +""" +import typing +from typing import List +from typing import Optional +from typing import Set + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + # this currently doesnt generate an error. not sure how to get the + # overloads to hit this one, nor am i sure i really want to do that + # anyway + name_this_works_atm: Mapped[str] = mapped_column(nullable=True) + + extra: Mapped[Optional[str]] = mapped_column() + extra_name: Mapped[Optional[str]] = mapped_column("extra_name") + + addresses_style_one: Mapped[List["Address"]] = relationship() + addresses_style_two: Mapped[Set["Address"]] = relationship() + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + email_name = mapped_column("email_name", String, nullable=False) + + user_style_one: Mapped[User] = relationship() + user_style_two: Mapped["User"] = relationship() + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] + reveal_type(User.extra) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Union\[builtins.str, None\]\] + reveal_type(User.extra_name) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + reveal_type(Address.email) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*\] + reveal_type(Address.email_name) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[experimental_relationship.Address\]\] + reveal_type(User.addresses_style_one) + + # EXPECTED_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*\[experimental_relationship.Address\]\] + reveal_type(User.addresses_style_two) diff --git a/test/ext/mypy/plain_files/mapped_column.py b/test/ext/mypy/plain_files/mapped_column.py new file mode 100644 index 0000000000..b20beeb3a3 --- /dev/null +++ b/test/ext/mypy/plain_files/mapped_column.py @@ -0,0 +1,92 @@ +from typing import Optional + +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class X(Base): + __tablename__ = "x" + + id: Mapped[int] = mapped_column(primary_key=True) + int_id: Mapped[int] = mapped_column(Integer, primary_key=True) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + err_int_id: Mapped[Optional[int]] = mapped_column( + Integer, primary_key=True + ) + + id_name: Mapped[int] = mapped_column("id_name", primary_key=True) + int_id_name: Mapped[int] = mapped_column( + "int_id_name", Integer, primary_key=True + ) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + err_int_id_name: Mapped[Optional[int]] = mapped_column( + "err_int_id_name", Integer, primary_key=True + ) + + # note we arent getting into primary_key=True / nullable=True here. + # leaving that as undefined for now + + a: Mapped[str] = mapped_column() + b: Mapped[Optional[str]] = mapped_column() + + # can't detect error because no SQL type is present + c: Mapped[str] = mapped_column(nullable=True) + d: Mapped[str] = mapped_column(nullable=False) + + e: Mapped[Optional[str]] = mapped_column(nullable=True) + + # can't detect error because no SQL type is present + f: Mapped[Optional[str]] = mapped_column(nullable=False) + + g: Mapped[str] = mapped_column(String) + h: Mapped[Optional[str]] = mapped_column(String) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + i: Mapped[str] = mapped_column(String, nullable=True) + + j: Mapped[str] = mapped_column(String, nullable=False) + + k: Mapped[Optional[str]] = mapped_column(String, nullable=True) + + # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type + l: Mapped[Optional[str]] = mapped_column(String, nullable=False) + + a_name: Mapped[str] = mapped_column("a_name") + b_name: Mapped[Optional[str]] = mapped_column("b_name") + + # can't detect error because no SQL type is present + c_name: Mapped[str] = mapped_column("c_name", nullable=True) + d_name: Mapped[str] = mapped_column("d_name", nullable=False) + + e_name: Mapped[Optional[str]] = mapped_column("e_name", nullable=True) + + # can't detect error because no SQL type is present + f_name: Mapped[Optional[str]] = mapped_column("f_name", nullable=False) + + g_name: Mapped[str] = mapped_column("g_name", String) + h_name: Mapped[Optional[str]] = mapped_column("h_name", String) + + # EXPECTED_MYPY: No overload variant of "mapped_column" matches argument types + i_name: Mapped[str] = mapped_column("i_name", String, nullable=True) + + j_name: Mapped[str] = mapped_column("j_name", String, nullable=False) + + k_name: Mapped[Optional[str]] = mapped_column( + "k_name", String, nullable=True + ) + + l_name: Mapped[Optional[str]] = mapped_column( + "l_name", + # EXPECTED_MYPY_RE: Argument \d to "mapped_column" has incompatible type + String, + nullable=False, + ) diff --git a/test/ext/mypy/plain_files/trad_relationship_uselist.py b/test/ext/mypy/plain_files/trad_relationship_uselist.py new file mode 100644 index 0000000000..a372fe2d1c --- /dev/null +++ b/test/ext/mypy/plain_files/trad_relationship_uselist.py @@ -0,0 +1,133 @@ +"""traditional relationship patterns with explicit uselist. + + +""" +import typing +from typing import cast +from typing import Dict +from typing import List +from typing import Set +from typing import Type + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm.collections import attribute_mapped_collection + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=False) + + addresses_style_one: Mapped[List["Address"]] = relationship( + "Address", uselist=True + ) + + addresses_style_two: Mapped[Set["Address"]] = relationship( + "Address", collection_class=set + ) + + addresses_style_three = relationship("Address", collection_class=set) + + addresses_style_three_cast = relationship( + cast(Type["Address"], "Address"), collection_class=set + ) + + addresses_style_four = relationship("Address", collection_class=list) + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + + user_style_one = relationship(User, uselist=False) + + user_style_one_typed: Mapped[User] = relationship(User, uselist=False) + + user_style_two = relationship("User", uselist=False) + + user_style_two_typed: Mapped["User"] = relationship("User", uselist=False) + + # these is obviously not correct relationally but want to see the typing + # work out with a real class passed as the argument + user_style_three: Mapped[List[User]] = relationship(User, uselist=True) + + user_style_four: Mapped[List[User]] = relationship("User", uselist=True) + + user_style_five: Mapped[List[User]] = relationship(User, uselist=True) + + user_style_six: Mapped[Set[User]] = relationship( + User, uselist=True, collection_class=set + ) + + user_style_seven = relationship(User, uselist=True, collection_class=set) + + user_style_eight = relationship(User, uselist=True, collection_class=list) + + user_style_nine = relationship(User, uselist=True) + + user_style_ten = relationship( + User, collection_class=attribute_mapped_collection("name") + ) + + user_style_ten_typed: Mapped[Dict[str, User]] = relationship( + User, collection_class=attribute_mapped_collection("name") + ) + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.Address\]\] + reveal_type(User.addresses_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + reveal_type(User.addresses_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[Any\]\] + reveal_type(User.addresses_style_three) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[trad_relationship_uselist.Address\]\] + reveal_type(User.addresses_style_three_cast) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[Any\]\] + reveal_type(User.addresses_style_four) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + reveal_type(Address.user_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + reveal_type(Address.user_style_one_typed) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + reveal_type(Address.user_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[trad_relationship_uselist.User\*\] + reveal_type(Address.user_style_two_typed) + + # reveal_type(Address.user_style_six) + + # reveal_type(Address.user_style_seven) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + reveal_type(Address.user_style_eight) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[trad_relationship_uselist.User\]\] + reveal_type(Address.user_style_nine) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + reveal_type(Address.user_style_ten) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.dict\*\[builtins.str, trad_relationship_uselist.User\]\] + reveal_type(Address.user_style_ten_typed) diff --git a/test/ext/mypy/plain_files/traditional_relationship.py b/test/ext/mypy/plain_files/traditional_relationship.py new file mode 100644 index 0000000000..473ccb2824 --- /dev/null +++ b/test/ext/mypy/plain_files/traditional_relationship.py @@ -0,0 +1,88 @@ +"""Here we illustrate 'traditional' relationship that looks as much like +1.x SQLAlchemy as possible. We want to illustrate that users can apply +Mapped[...] on the left hand side and that this will work in all cases. +This requires that the return type of relationship is based on Any, +if no uselists are present. + +""" +import typing +from typing import List +from typing import Set + +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "user" + + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=False) + + addresses_style_one: Mapped[List["Address"]] = relationship("Address") + + addresses_style_two: Mapped[Set["Address"]] = relationship( + "Address", collection_class=set + ) + + +class Address(Base): + __tablename__ = "address" + + id = mapped_column(Integer, primary_key=True) + user_id = mapped_column(ForeignKey("user.id")) + email = mapped_column(String, nullable=False) + + user_style_one = relationship(User) + + user_style_one_typed: Mapped[User] = relationship(User) + + user_style_two = relationship("User") + + user_style_two_typed: Mapped["User"] = relationship("User") + + # this is obviously not correct relationally but want to see the typing + # work out + user_style_three: Mapped[List[User]] = relationship(User) + + user_style_four: Mapped[List[User]] = relationship("User") + + user_style_five = relationship(User, collection_class=set) + + +if typing.TYPE_CHECKING: + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.Address\]\] + reveal_type(User.addresses_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.Address\]\] + reveal_type(User.addresses_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + reveal_type(Address.user_style_one) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + reveal_type(Address.user_style_one_typed) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[Any\] + reveal_type(Address.user_style_two) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[traditional_relationship.User\*\] + reveal_type(Address.user_style_two_typed) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + reveal_type(Address.user_style_three) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*\[traditional_relationship.User\]\] + reveal_type(Address.user_style_four) + + # EXPECTED_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.set\*\[traditional_relationship.User\]\] + reveal_type(Address.user_style_five) diff --git a/test/ext/mypy/files/abstract_one.py b/test/ext/mypy/plugin_files/abstract_one.py similarity index 100% rename from test/ext/mypy/files/abstract_one.py rename to test/ext/mypy/plugin_files/abstract_one.py diff --git a/test/ext/mypy/files/as_declarative.py b/test/ext/mypy/plugin_files/as_declarative.py similarity index 100% rename from test/ext/mypy/files/as_declarative.py rename to test/ext/mypy/plugin_files/as_declarative.py diff --git a/test/ext/mypy/files/as_declarative_base.py b/test/ext/mypy/plugin_files/as_declarative_base.py similarity index 100% rename from test/ext/mypy/files/as_declarative_base.py rename to test/ext/mypy/plugin_files/as_declarative_base.py diff --git a/test/ext/mypy/files/boolean_col.py b/test/ext/mypy/plugin_files/boolean_col.py similarity index 100% rename from test/ext/mypy/files/boolean_col.py rename to test/ext/mypy/plugin_files/boolean_col.py diff --git a/test/ext/mypy/files/cols_noninferred_plain_nonopt.py b/test/ext/mypy/plugin_files/cols_noninferred_plain_nonopt.py similarity index 100% rename from test/ext/mypy/files/cols_noninferred_plain_nonopt.py rename to test/ext/mypy/plugin_files/cols_noninferred_plain_nonopt.py diff --git a/test/ext/mypy/files/cols_notype_on_fk_col.py b/test/ext/mypy/plugin_files/cols_notype_on_fk_col.py similarity index 100% rename from test/ext/mypy/files/cols_notype_on_fk_col.py rename to test/ext/mypy/plugin_files/cols_notype_on_fk_col.py diff --git a/test/ext/mypy/files/complete_orm_no_plugin.py b/test/ext/mypy/plugin_files/complete_orm_no_plugin.py similarity index 100% rename from test/ext/mypy/files/complete_orm_no_plugin.py rename to test/ext/mypy/plugin_files/complete_orm_no_plugin.py diff --git a/test/ext/mypy/files/composite_props.py b/test/ext/mypy/plugin_files/composite_props.py similarity index 95% rename from test/ext/mypy/files/composite_props.py rename to test/ext/mypy/plugin_files/composite_props.py index f92b93c57d..d717ca0489 100644 --- a/test/ext/mypy/files/composite_props.py +++ b/test/ext/mypy/plugin_files/composite_props.py @@ -52,7 +52,7 @@ v1 = Vertex(start=Point(3, 4), end=Point(5, 6)) # I'm not even sure composites support this but it should work from a # typing perspective -stmt = select(v1).where(Vertex.start.in_([Point(3, 4)])) +stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)])) p1: Point = v1.start p2: Point = v1.end diff --git a/test/ext/mypy/files/constr_cols_only.py b/test/ext/mypy/plugin_files/constr_cols_only.py similarity index 100% rename from test/ext/mypy/files/constr_cols_only.py rename to test/ext/mypy/plugin_files/constr_cols_only.py diff --git a/test/ext/mypy/files/dataclasses_workaround.py b/test/ext/mypy/plugin_files/dataclasses_workaround.py similarity index 95% rename from test/ext/mypy/files/dataclasses_workaround.py rename to test/ext/mypy/plugin_files/dataclasses_workaround.py index 56c61b3333..9928b5a335 100644 --- a/test/ext/mypy/files/dataclasses_workaround.py +++ b/test/ext/mypy/plugin_files/dataclasses_workaround.py @@ -4,6 +4,8 @@ from __future__ import annotations from dataclasses import dataclass from dataclasses import field +from typing import Any +from typing import Dict from typing import List from typing import Optional from typing import TYPE_CHECKING @@ -40,7 +42,7 @@ class User: if TYPE_CHECKING: _mypy_mapped_attrs = [id, name, fullname, nickname, addresses] - __mapper_args__ = { # type: ignore + __mapper_args__: Dict[str, Any] = { "properties": {"addresses": relationship("Address")} } diff --git a/test/ext/mypy/files/decl_attrs_one.py b/test/ext/mypy/plugin_files/decl_attrs_one.py similarity index 100% rename from test/ext/mypy/files/decl_attrs_one.py rename to test/ext/mypy/plugin_files/decl_attrs_one.py diff --git a/test/ext/mypy/files/decl_attrs_two.py b/test/ext/mypy/plugin_files/decl_attrs_two.py similarity index 100% rename from test/ext/mypy/files/decl_attrs_two.py rename to test/ext/mypy/plugin_files/decl_attrs_two.py diff --git a/test/ext/mypy/files/decl_base_subclass_one.py b/test/ext/mypy/plugin_files/decl_base_subclass_one.py similarity index 100% rename from test/ext/mypy/files/decl_base_subclass_one.py rename to test/ext/mypy/plugin_files/decl_base_subclass_one.py diff --git a/test/ext/mypy/files/decl_base_subclass_two.py b/test/ext/mypy/plugin_files/decl_base_subclass_two.py similarity index 100% rename from test/ext/mypy/files/decl_base_subclass_two.py rename to test/ext/mypy/plugin_files/decl_base_subclass_two.py diff --git a/test/ext/mypy/files/declarative_base_dynamic.py b/test/ext/mypy/plugin_files/declarative_base_dynamic.py similarity index 100% rename from test/ext/mypy/files/declarative_base_dynamic.py rename to test/ext/mypy/plugin_files/declarative_base_dynamic.py diff --git a/test/ext/mypy/files/declarative_base_explicit.py b/test/ext/mypy/plugin_files/declarative_base_explicit.py similarity index 100% rename from test/ext/mypy/files/declarative_base_explicit.py rename to test/ext/mypy/plugin_files/declarative_base_explicit.py diff --git a/test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py b/test/ext/mypy/plugin_files/ensure_descriptor_type_fully_inferred.py similarity index 100% rename from test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py rename to test/ext/mypy/plugin_files/ensure_descriptor_type_fully_inferred.py diff --git a/test/ext/mypy/files/ensure_descriptor_type_noninferred.py b/test/ext/mypy/plugin_files/ensure_descriptor_type_noninferred.py similarity index 100% rename from test/ext/mypy/files/ensure_descriptor_type_noninferred.py rename to test/ext/mypy/plugin_files/ensure_descriptor_type_noninferred.py diff --git a/test/ext/mypy/files/ensure_descriptor_type_semiinferred.py b/test/ext/mypy/plugin_files/ensure_descriptor_type_semiinferred.py similarity index 100% rename from test/ext/mypy/files/ensure_descriptor_type_semiinferred.py rename to test/ext/mypy/plugin_files/ensure_descriptor_type_semiinferred.py diff --git a/test/ext/mypy/files/enum_col.py b/test/ext/mypy/plugin_files/enum_col.py similarity index 100% rename from test/ext/mypy/files/enum_col.py rename to test/ext/mypy/plugin_files/enum_col.py diff --git a/test/ext/mypy/files/imperative_table.py b/test/ext/mypy/plugin_files/imperative_table.py similarity index 100% rename from test/ext/mypy/files/imperative_table.py rename to test/ext/mypy/plugin_files/imperative_table.py diff --git a/test/ext/mypy/files/invalid_noninferred_lh_type.py b/test/ext/mypy/plugin_files/invalid_noninferred_lh_type.py similarity index 100% rename from test/ext/mypy/files/invalid_noninferred_lh_type.py rename to test/ext/mypy/plugin_files/invalid_noninferred_lh_type.py diff --git a/test/ext/mypy/files/issue_7321.py b/test/ext/mypy/plugin_files/issue_7321.py similarity index 100% rename from test/ext/mypy/files/issue_7321.py rename to test/ext/mypy/plugin_files/issue_7321.py diff --git a/test/ext/mypy/files/issue_7321_part2.py b/test/ext/mypy/plugin_files/issue_7321_part2.py similarity index 100% rename from test/ext/mypy/files/issue_7321_part2.py rename to test/ext/mypy/plugin_files/issue_7321_part2.py diff --git a/test/ext/mypy/files/mapped_attr_assign.py b/test/ext/mypy/plugin_files/mapped_attr_assign.py similarity index 100% rename from test/ext/mypy/files/mapped_attr_assign.py rename to test/ext/mypy/plugin_files/mapped_attr_assign.py diff --git a/test/ext/mypy/files/mixin_not_mapped.py b/test/ext/mypy/plugin_files/mixin_not_mapped.py similarity index 100% rename from test/ext/mypy/files/mixin_not_mapped.py rename to test/ext/mypy/plugin_files/mixin_not_mapped.py diff --git a/test/ext/mypy/files/mixin_one.py b/test/ext/mypy/plugin_files/mixin_one.py similarity index 100% rename from test/ext/mypy/files/mixin_one.py rename to test/ext/mypy/plugin_files/mixin_one.py diff --git a/test/ext/mypy/files/mixin_three.py b/test/ext/mypy/plugin_files/mixin_three.py similarity index 100% rename from test/ext/mypy/files/mixin_three.py rename to test/ext/mypy/plugin_files/mixin_three.py diff --git a/test/ext/mypy/files/mixin_two.py b/test/ext/mypy/plugin_files/mixin_two.py similarity index 95% rename from test/ext/mypy/files/mixin_two.py rename to test/ext/mypy/plugin_files/mixin_two.py index c4dc610973..897ce82498 100644 --- a/test/ext/mypy/files/mixin_two.py +++ b/test/ext/mypy/plugin_files/mixin_two.py @@ -6,6 +6,7 @@ from sqlalchemy import String from sqlalchemy.orm import deferred from sqlalchemy.orm import Mapped from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship from sqlalchemy.orm.decl_api import declared_attr from sqlalchemy.orm.interfaces import MapperProperty @@ -36,11 +37,11 @@ class HasAMixin: return relationship("A", back_populates="bs") @declared_attr - def a3(cls) -> relationship["A"]: + def a3(cls) -> Relationship["A"]: return relationship("A", back_populates="bs") @declared_attr - def c1(cls) -> relationship[C]: + def c1(cls) -> Relationship[C]: return relationship(C, back_populates="bs") @declared_attr diff --git a/test/ext/mypy/files/mixin_w_tablename.py b/test/ext/mypy/plugin_files/mixin_w_tablename.py similarity index 100% rename from test/ext/mypy/files/mixin_w_tablename.py rename to test/ext/mypy/plugin_files/mixin_w_tablename.py diff --git a/test/ext/mypy/files/orderinglist1.py b/test/ext/mypy/plugin_files/orderinglist1.py similarity index 100% rename from test/ext/mypy/files/orderinglist1.py rename to test/ext/mypy/plugin_files/orderinglist1.py diff --git a/test/ext/mypy/files/orderinglist2.py b/test/ext/mypy/plugin_files/orderinglist2.py similarity index 100% rename from test/ext/mypy/files/orderinglist2.py rename to test/ext/mypy/plugin_files/orderinglist2.py diff --git a/test/ext/mypy/files/other_mapper_props.py b/test/ext/mypy/plugin_files/other_mapper_props.py similarity index 100% rename from test/ext/mypy/files/other_mapper_props.py rename to test/ext/mypy/plugin_files/other_mapper_props.py diff --git a/test/ext/mypy/files/plugin_doesnt_break_one.py b/test/ext/mypy/plugin_files/plugin_doesnt_break_one.py similarity index 100% rename from test/ext/mypy/files/plugin_doesnt_break_one.py rename to test/ext/mypy/plugin_files/plugin_doesnt_break_one.py diff --git a/test/ext/mypy/files/relationship_6255_one.py b/test/ext/mypy/plugin_files/relationship_6255_one.py similarity index 74% rename from test/ext/mypy/files/relationship_6255_one.py rename to test/ext/mypy/plugin_files/relationship_6255_one.py index e5a180b479..0c8e3c4f64 100644 --- a/test/ext/mypy/files/relationship_6255_one.py +++ b/test/ext/mypy/plugin_files/relationship_6255_one.py @@ -1,13 +1,13 @@ from typing import List from typing import Optional -from sqlalchemy import Column from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship Base = declarative_base() @@ -16,8 +16,8 @@ Base = declarative_base() class User(Base): __tablename__ = "user" - id = Column(Integer, primary_key=True) - name = Column(String) + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String, nullable=True) addresses: Mapped[List["Address"]] = relationship( "Address", back_populates="user" @@ -31,10 +31,10 @@ class User(Base): class Address(Base): __tablename__ = "address" - id = Column(Integer, primary_key=True) - user_id: int = Column(ForeignKey("user.id")) + id = mapped_column(Integer, primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) - user: "User" = relationship("User", back_populates="addresses") + user: Mapped["User"] = relationship("User", back_populates="addresses") @property def some_other_property(self) -> Optional[str]: diff --git a/test/ext/mypy/files/relationship_6255_three.py b/test/ext/mypy/plugin_files/relationship_6255_three.py similarity index 100% rename from test/ext/mypy/files/relationship_6255_three.py rename to test/ext/mypy/plugin_files/relationship_6255_three.py diff --git a/test/ext/mypy/files/relationship_6255_two.py b/test/ext/mypy/plugin_files/relationship_6255_two.py similarity index 100% rename from test/ext/mypy/files/relationship_6255_two.py rename to test/ext/mypy/plugin_files/relationship_6255_two.py diff --git a/test/ext/mypy/files/relationship_direct_cls.py b/test/ext/mypy/plugin_files/relationship_direct_cls.py similarity index 100% rename from test/ext/mypy/files/relationship_direct_cls.py rename to test/ext/mypy/plugin_files/relationship_direct_cls.py diff --git a/test/ext/mypy/files/relationship_err1.py b/test/ext/mypy/plugin_files/relationship_err1.py similarity index 91% rename from test/ext/mypy/files/relationship_err1.py rename to test/ext/mypy/plugin_files/relationship_err1.py index 46e7067d34..ba3783f05a 100644 --- a/test/ext/mypy/files/relationship_err1.py +++ b/test/ext/mypy/plugin_files/relationship_err1.py @@ -27,4 +27,5 @@ class A(Base): b_id: int = Column(ForeignKey("b.id")) # EXPECTED: Sending uselist=False and collection_class at the same time does not make sense # noqa + # EXPECTED_MYPY_RE: No overload variant of "relationship" matches argument types b: B = relationship(B, uselist=False, collection_class=set) diff --git a/test/ext/mypy/files/relationship_err2.py b/test/ext/mypy/plugin_files/relationship_err2.py similarity index 100% rename from test/ext/mypy/files/relationship_err2.py rename to test/ext/mypy/plugin_files/relationship_err2.py diff --git a/test/ext/mypy/files/relationship_err3.py b/test/ext/mypy/plugin_files/relationship_err3.py similarity index 100% rename from test/ext/mypy/files/relationship_err3.py rename to test/ext/mypy/plugin_files/relationship_err3.py diff --git a/test/ext/mypy/files/sa_module_prefix.py b/test/ext/mypy/plugin_files/sa_module_prefix.py similarity index 100% rename from test/ext/mypy/files/sa_module_prefix.py rename to test/ext/mypy/plugin_files/sa_module_prefix.py diff --git a/test/ext/mypy/files/t_6950.py b/test/ext/mypy/plugin_files/t_6950.py similarity index 100% rename from test/ext/mypy/files/t_6950.py rename to test/ext/mypy/plugin_files/t_6950.py diff --git a/test/ext/mypy/files/type_decorator.py b/test/ext/mypy/plugin_files/type_decorator.py similarity index 100% rename from test/ext/mypy/files/type_decorator.py rename to test/ext/mypy/plugin_files/type_decorator.py diff --git a/test/ext/mypy/files/typeless_fk_col_cant_infer.py b/test/ext/mypy/plugin_files/typeless_fk_col_cant_infer.py similarity index 100% rename from test/ext/mypy/files/typeless_fk_col_cant_infer.py rename to test/ext/mypy/plugin_files/typeless_fk_col_cant_infer.py diff --git a/test/ext/mypy/files/typing_err1.py b/test/ext/mypy/plugin_files/typing_err1.py similarity index 100% rename from test/ext/mypy/files/typing_err1.py rename to test/ext/mypy/plugin_files/typing_err1.py diff --git a/test/ext/mypy/files/typing_err2.py b/test/ext/mypy/plugin_files/typing_err2.py similarity index 88% rename from test/ext/mypy/files/typing_err2.py rename to test/ext/mypy/plugin_files/typing_err2.py index adc50f9890..ec56358755 100644 --- a/test/ext/mypy/files/typing_err2.py +++ b/test/ext/mypy/plugin_files/typing_err2.py @@ -3,6 +3,7 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.orm import declared_attr from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship reg: registry = registry() @@ -29,8 +30,8 @@ class Foo: # EXPECTED: Can't infer type from @declared_attr on function 'some_relationship' # noqa @declared_attr - # EXPECTED_MYPY: Missing type parameters for generic type "relationship" - def some_relationship(cls) -> relationship: + # EXPECTED_MYPY: Missing type parameters for generic type "Relationship" + def some_relationship(cls) -> Relationship: return relationship("Bar") diff --git a/test/ext/mypy/files/typing_err3.py b/test/ext/mypy/plugin_files/typing_err3.py similarity index 93% rename from test/ext/mypy/files/typing_err3.py rename to test/ext/mypy/plugin_files/typing_err3.py index 5383f89560..466e636a78 100644 --- a/test/ext/mypy/files/typing_err3.py +++ b/test/ext/mypy/plugin_files/typing_err3.py @@ -49,6 +49,5 @@ class Address(Base): @declared_attr # EXPECTED_MYPY: Invalid type comment or annotation def thisisweird(cls) -> Column(String): - # with the bad annotation mypy seems to not go into the - # function body + # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa return Column(False) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index cc8d8955f6..6df21e46c0 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -3,16 +3,54 @@ import re import shutil import sys import tempfile +from typing import Any +from typing import cast +from typing import List +from typing import Tuple +import sqlalchemy from sqlalchemy import testing from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +def _file_combinations(dirname): + path = os.path.join(os.path.dirname(__file__), dirname) + files = [] + for f in os.listdir(path): + if f.endswith(".py"): + files.append(os.path.join(os.path.dirname(__file__), dirname, f)) + + for extra_dir in testing.config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + for f in os.listdir(os.path.join(extra_dir, dirname)): + if f.endswith(".py"): + files.append(os.path.join(extra_dir, dirname, f)) + return files + + +def _incremental_dirs(): + path = os.path.join(os.path.dirname(__file__), "incremental") + files = [] + for d in os.listdir(path): + if os.path.isdir(os.path.join(path, d)): + files.append( + os.path.join(os.path.dirname(__file__), "incremental", d) + ) + + for extra_dir in testing.config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + for d in os.listdir(os.path.join(extra_dir, "incremental")): + if os.path.isdir(os.path.join(path, d)): + files.append(os.path.join(extra_dir, "incremental", d)) + return files + + @testing.add_to_marker.mypy class MypyPluginTest(fixtures.TestBase): - __requires__ = ("sqlalchemy2_stubs",) + __tags__ = ("mypy",) + __requires__ = ("no_sqlalchemy2_stubs",) @testing.fixture(scope="function") def per_func_cachedir(self): @@ -25,22 +63,50 @@ class MypyPluginTest(fixtures.TestBase): yield item def _cachedir(self): + sqlalchemy_path = os.path.dirname(os.path.dirname(sqlalchemy.__file__)) + + # for a pytest from my local ./lib/ , i need mypy_path. + # for a tox run where sqlalchemy is in site_packages, mypy complains + # "../python3.10/site-packages is in the MYPYPATH. Please remove it." + # previously when we used sqlalchemy2-stubs, it would just be + # installed as a dependency, which is why mypy_path wasn't needed + # then, but I like to be able to run the test suite from the local + # ./lib/ as well. + + if "site-packages" not in sqlalchemy_path: + mypy_path = f"mypy_path={sqlalchemy_path}" + else: + mypy_path = "" + with tempfile.TemporaryDirectory() as cachedir: with open( os.path.join(cachedir, "sqla_mypy_config.cfg"), "w" ) as config_file: config_file.write( - """ + f""" [mypy]\n plugins = sqlalchemy.ext.mypy.plugin\n + show_error_codes = True\n + {mypy_path} + disable_error_code = no-untyped-call + + [mypy-sqlalchemy.*] + ignore_errors = True + """ ) with open( os.path.join(cachedir, "plain_mypy_config.cfg"), "w" ) as config_file: config_file.write( - """ + f""" [mypy]\n + show_error_codes = True\n + {mypy_path} + disable_error_code = var-annotated,no-untyped-call + [mypy-sqlalchemy.*] + ignore_errors = True + """ ) yield cachedir @@ -70,24 +136,12 @@ class MypyPluginTest(fixtures.TestBase): return run - def _incremental_dirs(): - path = os.path.join(os.path.dirname(__file__), "incremental") - files = [] - for d in os.listdir(path): - if os.path.isdir(os.path.join(path, d)): - files.append( - os.path.join(os.path.dirname(__file__), "incremental", d) - ) - - for extra_dir in testing.config.options.mypy_extra_test_paths: - if extra_dir and os.path.isdir(extra_dir): - for d in os.listdir(os.path.join(extra_dir, "incremental")): - if os.path.isdir(os.path.join(path, d)): - files.append(os.path.join(extra_dir, "incremental", d)) - return files - @testing.combinations( - *[(pathname,) for pathname in _incremental_dirs()], argnames="pathname" + *[ + (pathname, testing.exclusions.closed()) + for pathname in _incremental_dirs() + ], + argnames="pathname", ) @testing.requires.patch_library def test_incremental(self, mypy_runner, per_func_cachedir, pathname): @@ -131,33 +185,33 @@ class MypyPluginTest(fixtures.TestBase): % (patchfile, result[0]), ) - def _file_combinations(): - path = os.path.join(os.path.dirname(__file__), "files") - files = [] - for f in os.listdir(path): - if f.endswith(".py"): - files.append( - os.path.join(os.path.dirname(__file__), "files", f) - ) - - for extra_dir in testing.config.options.mypy_extra_test_paths: - if extra_dir and os.path.isdir(extra_dir): - for f in os.listdir(os.path.join(extra_dir, "files")): - if f.endswith(".py"): - files.append(os.path.join(extra_dir, "files", f)) - return files - @testing.combinations( - *[(filename,) for filename in _file_combinations()], argnames="path" + *( + cast( + List[Tuple[Any, ...]], + [ + ("w_plugin", os.path.basename(path), path, True) + for path in _file_combinations("plugin_files") + ], + ) + + cast( + List[Tuple[Any, ...]], + [ + ("plain", os.path.basename(path), path, False) + for path in _file_combinations("plain_files") + ], + ) + ), + argnames="filename,path,use_plugin", + id_="isaa", ) - def test_mypy(self, mypy_runner, path): - filename = os.path.basename(path) - use_plugin = True + def test_files(self, mypy_runner, filename, path, use_plugin): - expected_errors = [] - expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?: (.+)") + expected_messages = [] + expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") with open(path) as file_: + current_assert_messages = [] for num, line in enumerate(file_, 1): m = py_ver_re.match(line) if m: @@ -174,38 +228,79 @@ class MypyPluginTest(fixtures.TestBase): m = expected_re.match(line) if m: is_mypy = bool(m.group(1)) - expected_msg = m.group(2) - expected_msg = re.sub(r"# noqa ?.*", "", m.group(2)) - expected_errors.append( - (num, is_mypy, expected_msg.strip()) + is_re = bool(m.group(2)) + is_type = bool(m.group(3)) + + expected_msg = re.sub(r"# noqa ?.*", "", m.group(4)) + if is_type: + is_mypy = is_re = True + expected_msg = f'Revealed type is "{expected_msg}"' + current_assert_messages.append( + (is_mypy, is_re, expected_msg.strip()) + ) + elif current_assert_messages: + expected_messages.extend( + (num, is_mypy, is_re, expected_msg) + for ( + is_mypy, + is_re, + expected_msg, + ) in current_assert_messages ) + current_assert_messages[:] = [] result = mypy_runner(path, use_plugin=use_plugin) - if expected_errors: + if expected_messages: eq_(result[2], 1, msg=result) - print(result[0]) + output = [] - errors = [] - for e in result[0].split("\n"): + raw_lines = result[0].split("\n") + while raw_lines: + e = raw_lines.pop(0) if re.match(r".+\.py:\d+: error: .*", e): - errors.append(e) - - for num, is_mypy, msg in expected_errors: + output.append(("error", e)) + elif re.match( + r".+\.py:\d+: note: +(?:Possible overload|def ).*", e + ): + while raw_lines: + ol = raw_lines.pop(0) + if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + break + elif re.match( + r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I + ): + pass + elif re.match(r".+\.py:\d+: note: .*", e): + output.append(("note", e)) + + for num, is_mypy, is_re, msg in expected_messages: msg = msg.replace("'", '"') prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" - for idx, errmsg in enumerate(errors): - if ( - f"{filename}:{num + 1}: error: {prefix}{msg}" + for idx, (typ, errmsg) in enumerate(output): + if is_re: + if re.match( + fr".*{filename}\:{num}\: {typ}\: {prefix}{msg}", # noqa E501 + errmsg, + ): + break + elif ( + f"{filename}:{num}: {typ}: {prefix}{msg}" in errmsg.replace("'", '"') ): break else: continue - del errors[idx] + del output[idx] - assert not errors, "errors remain: %s" % "\n".join(errors) + if output: + print("messages from mypy that were not consumed:") + print("\n".join(msg for _, msg in output)) + assert False, "errors and/or notes remain, see stdout" else: + if result[2] != 0: + print(result[0]) + eq_(result[2], 0, msg=result) diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index 9a05e1fae3..76fd90fa88 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -231,7 +231,7 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): serializer.loads(pickled_failing, users.metadata, None) def test_orm_join(self): - from sqlalchemy.orm.util import join + from sqlalchemy.orm import join j = join(User, Address, User.addresses) diff --git a/test/orm/declarative/__init__.py b/test/orm/declarative/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 9651f6dbfd..9f9f8e601e 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -563,26 +563,6 @@ class DeclarativeMultiBaseTest( eq_(a1, Address(email="two")) eq_(a1.user, User(name="u1")) - def test_mapped_column_construct(self): - class User(Base, fixtures.ComparableEntity): - __tablename__ = "users" - - id = mapped_column("id", Integer, primary_key=True) - name = mapped_column(String(50)) - - Base.metadata.create_all(testing.db) - - u1 = User(id=1, name="u1") - sess = fixture_session() - sess.add(u1) - sess.flush() - sess.expunge_all() - - eq_( - sess.query(User).all(), - [User(name="u1", id=1)], - ) - def test_back_populates_setup(self): class User(Base): __tablename__ = "users" @@ -1534,28 +1514,25 @@ class DeclarativeMultiBaseTest( yield go + @testing.combinations(Column, mapped_column, argnames="_column") def test_add_prop_auto( - self, require_metaclass, assert_user_address_mapping + self, require_metaclass, assert_user_address_mapping, _column ): class User(Base, fixtures.ComparableEntity): __tablename__ = "users" - id = Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ) + id = Column("id", Integer, primary_key=True) - User.name = Column("name", String(50)) + User.name = _column("name", String(50)) User.addresses = relationship("Address", backref="user") class Address(Base, fixtures.ComparableEntity): __tablename__ = "addresses" - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) + id = _column(Integer, primary_key=True) - Address.email = Column(String(50), key="_email") - Address.user_id = Column( + Address.email = _column(String(50), key="_email") + Address.user_id = _column( "user_id", Integer, ForeignKey("users.id"), key="_user_id" ) @@ -1565,15 +1542,14 @@ class DeclarativeMultiBaseTest( assert_user_address_mapping(User, Address) - def test_add_prop_manual(self, assert_user_address_mapping): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_add_prop_manual(self, assert_user_address_mapping, _column): class User(Base, fixtures.ComparableEntity): __tablename__ = "users" - id = Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ) + id = _column("id", Integer, primary_key=True) - add_mapped_attribute(User, "name", Column("name", String(50))) + add_mapped_attribute(User, "name", _column("name", String(50))) add_mapped_attribute( User, "addresses", relationship("Address", backref="user") ) @@ -1581,17 +1557,17 @@ class DeclarativeMultiBaseTest( class Address(Base, fixtures.ComparableEntity): __tablename__ = "addresses" - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) + id = _column(Integer, primary_key=True) add_mapped_attribute( - Address, "email", Column(String(50), key="_email") + Address, "email", _column(String(50), key="_email") ) add_mapped_attribute( Address, "user_id", - Column("user_id", Integer, ForeignKey("users.id"), key="_user_id"), + _column( + "user_id", Integer, ForeignKey("users.id"), key="_user_id" + ), ) eq_(Address.__table__.c["id"].name, "id") @@ -1612,7 +1588,7 @@ class DeclarativeMultiBaseTest( assert ASub.brap.property is A.data.property assert isinstance( - ASub.brap.original_property, descriptor_props.SynonymProperty + ASub.brap.original_property, descriptor_props.Synonym ) def test_alt_name_attr_subclass_relationship_inline(self): @@ -1634,7 +1610,7 @@ class DeclarativeMultiBaseTest( assert ASub.brap.property is A.b.property assert isinstance( - ASub.brap.original_property, descriptor_props.SynonymProperty + ASub.brap.original_property, descriptor_props.Synonym ) ASub(brap=B()) @@ -1647,9 +1623,7 @@ class DeclarativeMultiBaseTest( A.brap = A.data assert A.brap.property is A.data.property - assert isinstance( - A.brap.original_property, descriptor_props.SynonymProperty - ) + assert isinstance(A.brap.original_property, descriptor_props.Synonym) def test_alt_name_attr_subclass_relationship_attrset( self, require_metaclass @@ -1668,9 +1642,7 @@ class DeclarativeMultiBaseTest( id = Column("id", Integer, primary_key=True) assert A.brap.property is A.b.property - assert isinstance( - A.brap.original_property, descriptor_props.SynonymProperty - ) + assert isinstance(A.brap.original_property, descriptor_props.Synonym) A(brap=B()) def test_eager_order_by(self): diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 97f0d560e4..5be8237e26 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -1,3 +1,5 @@ +from operator import is_not + import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy import func @@ -18,15 +20,18 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import events as orm_events from sqlalchemy.orm import has_inherited_table +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import synonym from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -375,16 +380,88 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(m1.tables["user"].c.keys(), ["id", "name", "surname"]) eq_(m2.tables["user"].c.keys(), ["id", "username"]) - def test_not_allowed(self): + @testing.combinations(Column, mapped_column, argnames="_column") + @testing.combinations("strname", "colref", "objref", argnames="fk_type") + def test_fk_mixin(self, decl_base, fk_type, _column): + class Bar(decl_base): + __tablename__ = "bar" + + id = _column(Integer, primary_key=True) + + if fk_type == "strname": + fk = ForeignKey("bar.id") + elif fk_type == "colref": + fk = ForeignKey(Bar.__table__.c.id) + elif fk_type == "objref": + fk = ForeignKey(Bar.id) + else: + assert False + class MyMixin: - foo = Column(Integer, ForeignKey("bar.id")) + foo = _column(Integer, fk) - def go(): - class MyModel(Base, MyMixin): - __tablename__ = "foo" + class A(MyMixin, decl_base): + __tablename__ = "a" - assert_raises(sa.exc.InvalidRequestError, go) + id = _column(Integer, primary_key=True) + + class B(MyMixin, decl_base): + __tablename__ = "b" + + id = _column(Integer, primary_key=True) + + is_true(A.__table__.c.foo.references(Bar.__table__.c.id)) + is_true(B.__table__.c.foo.references(Bar.__table__.c.id)) + + fka = list(A.__table__.c.foo.foreign_keys)[0] + fkb = list(A.__table__.c.foo.foreign_keys)[0] + is_not(fka, fkb) + + @testing.combinations(Column, mapped_column, argnames="_column") + def test_fk_mixin_self_referential_error(self, decl_base, _column): + class MyMixin: + id = _column(Integer, primary_key=True) + foo = _column(Integer, ForeignKey(id)) + with expect_raises_message( + sa.exc.InvalidRequestError, + "Columns with foreign keys to non-table-bound columns " + "must be declared as @declared_attr", + ): + + class A(MyMixin, decl_base): + __tablename__ = "a" + + @testing.combinations(Column, mapped_column, argnames="_column") + def test_fk_mixin_self_referential_declared_attr(self, decl_base, _column): + class MyMixin: + id = _column(Integer, primary_key=True) + + @declared_attr + def foo(cls): + return _column(Integer, ForeignKey(cls.id)) + + class A(MyMixin, decl_base): + __tablename__ = "a" + + class B(MyMixin, decl_base): + __tablename__ = "b" + + is_true(A.__table__.c.foo.references(A.__table__.c.id)) + is_true(B.__table__.c.foo.references(B.__table__.c.id)) + + fka = list(A.__table__.c.foo.foreign_keys)[0] + fkb = list(A.__table__.c.foo.foreign_keys)[0] + is_not(fka, fkb) + + is_true(A.__table__.c.foo.references(A.__table__.c.id)) + is_true(B.__table__.c.foo.references(B.__table__.c.id)) + + fka = list(A.__table__.c.foo.foreign_keys)[0] + fkb = list(A.__table__.c.foo.foreign_keys)[0] + is_not(fka, fkb) + + def test_not_allowed(self): class MyRelMixin: foo = relationship("Bar") @@ -1013,7 +1090,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): __mapper_args__ = dict(polymorphic_identity="specific") assert Specific.__table__ is Generic.__table__ - eq_(list(Generic.__table__.c.keys()), ["id", "type", "value"]) + eq_(list(Generic.__table__.c.keys()), ["type", "value", "id"]) assert ( class_mapper(Specific).polymorphic_on is Generic.__table__.c.type ) @@ -1043,7 +1120,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(Specific.__table__.name, "specific") eq_( list(Generic.__table__.c.keys()), - ["timestamp", "id", "python_type"], + ["python_type", "timestamp", "id"], ) eq_(list(Specific.__table__.c.keys()), ["id"]) eq_(Generic.__table__.kwargs, {"mysql_engine": "InnoDB"}) @@ -1078,7 +1155,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(BaseType.__table__.name, "basetype") eq_( list(BaseType.__table__.c.keys()), - ["timestamp", "type", "id", "value"], + ["type", "id", "value", "timestamp"], ) eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"}) assert Single.__table__ is BaseType.__table__ @@ -1326,7 +1403,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_( list(Model.__table__.c.keys()), - ["col1", "col3", "col2", "col4", "id"], + ["id", "col1", "col3", "col2", "col4"], ) def test_honor_class_mro_one(self): @@ -1813,11 +1890,11 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): s = fixture_session() self.assert_compile( s.query(A), - "SELECT a.x AS a_x, a.x + :x_1 AS anon_1, a.id AS a_id FROM a", + "SELECT a.x + :x_1 AS anon_1, a.x AS a_x, a.id AS a_id FROM a", ) self.assert_compile( s.query(B), - "SELECT b.x AS b_x, b.x + :x_1 AS anon_1, b.id AS b_id FROM b", + "SELECT b.x + :x_1 AS anon_1, b.x AS b_x, b.id AS b_id FROM b", ) @testing.requires.predictable_gc @@ -2161,7 +2238,7 @@ class AbstractTest(DeclarativeTestBase): class C(B): c_value = Column(String) - eq_(sa.inspect(C).attrs.keys(), ["id", "name", "data", "c_value"]) + eq_(sa.inspect(C).attrs.keys(), ["id", "name", "c_value", "data"]) def test_implicit_abstract_viadecorator(self): @mapper_registry.mapped @@ -2178,7 +2255,7 @@ class AbstractTest(DeclarativeTestBase): class C(B): c_value = Column(String) - eq_(sa.inspect(C).attrs.keys(), ["id", "name", "data", "c_value"]) + eq_(sa.inspect(C).attrs.keys(), ["id", "name", "c_value", "data"]) def test_middle_abstract_inherits(self): # test for [ticket:3240] diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py new file mode 100644 index 0000000000..c7022dc31c --- /dev/null +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from .test_typed_mapping import MappedColumnTest # noqa +from .test_typed_mapping import RelationshipLHSTest # noqa + +"""runs the annotation-sensitive tests from test_typed_mappings while +having ``from __future__ import annotations`` in effect. + +""" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py new file mode 100644 index 0000000000..71eb7ce42b --- /dev/null +++ b/test/orm/declarative/test_typed_mapping.py @@ -0,0 +1,1048 @@ +import dataclasses +import datetime +from decimal import Decimal +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import TypeVar +from typing import Union + +from sqlalchemy import BIGINT +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import exc as sa_exc +from sqlalchemy import ForeignKey +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import Numeric +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import VARCHAR +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import as_declarative +from sqlalchemy.orm import composite +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import deferred +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship +from sqlalchemy.orm import undefer +from sqlalchemy.orm.collections import attribute_mapped_collection +from sqlalchemy.orm.collections import MappedCollection +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_not +from sqlalchemy.testing import is_true +from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.util.typing import Annotated + + +class DeclarativeBaseTest(fixtures.TestBase): + def test_class_getitem_as_declarative(self): + T = TypeVar("T", bound="CommonBase") # noqa + + class CommonBase(Generic[T]): + @classmethod + def boring(cls: Type[T]) -> Type[T]: + return cls + + @classmethod + def more_boring(cls: Type[T]) -> int: + return 27 + + @as_declarative() + class Base(CommonBase[T]): + foo = 1 + + class Tab(Base["Tab"]): + __tablename__ = "foo" + a = Column(Integer, primary_key=True) + + eq_(Tab.foo, 1) + is_(Tab.__table__, inspect(Tab).local_table) + eq_(Tab.boring(), Tab) + eq_(Tab.more_boring(), 27) + + with expect_raises(AttributeError): + Tab.non_existent + + +class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_legacy_declarative_base(self): + typ = VARCHAR(50) + Base = declarative_base(type_annotation_map={str: typ}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + x: Mapped[int] + + is_(MyClass.__table__.c.data.type, typ) + is_true(MyClass.__table__.c.id.primary_key) + + def test_required_no_arg(self, decl_base): + with expect_raises_message( + sa_exc.ArgumentError, + r"Python typing annotation is required for attribute " + r'"A.data" when primary ' + r'argument\(s\) for "MappedColumn" construct are None or ' + r"not present", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data = mapped_column() + + def test_construct_rhs(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id = mapped_column("id", Integer, primary_key=True) + name = mapped_column(String(50)) + + self.assert_compile( + select(User), "SELECT users.id, users.name FROM users" + ) + eq_(User.__mapper__.primary_key, (User.__table__.c.id,)) + + def test_construct_lhs(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + data: Mapped[Optional[str]] = mapped_column() + + self.assert_compile( + select(User), "SELECT users.id, users.name, users.data FROM users" + ) + eq_(User.__mapper__.primary_key, (User.__table__.c.id,)) + is_false(User.__table__.c.id.nullable) + is_false(User.__table__.c.name.nullable) + is_true(User.__table__.c.data.nullable) + + def test_construct_lhs_omit_mapped_column(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + data: Mapped[Optional[str]] + x: Mapped[int] + y: Mapped[int] + created_at: Mapped[datetime.datetime] + + self.assert_compile( + select(User), + "SELECT users.id, users.name, users.data, users.x, " + "users.y, users.created_at FROM users", + ) + eq_(User.__mapper__.primary_key, (User.__table__.c.id,)) + is_false(User.__table__.c.id.nullable) + is_false(User.__table__.c.name.nullable) + is_true(User.__table__.c.data.nullable) + assert isinstance(User.__table__.c.created_at.type, DateTime) + + def test_construct_lhs_type_missing(self, decl_base): + class MyClass: + pass + + with expect_raises_message( + sa_exc.ArgumentError, + "Could not locate SQLAlchemy Core type for Python type: .*MyClass", + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[MyClass] = mapped_column() + + def test_construct_rhs_type_override_lhs(self, decl_base): + class Element(decl_base): + __tablename__ = "element" + + id: Mapped[int] = mapped_column(BIGINT, primary_key=True) + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(BIGINT, primary_key=True) + other_id: Mapped[int] = mapped_column(ForeignKey("element.id")) + data: Mapped[int] = mapped_column() + + # exact class test + is_(User.__table__.c.id.type.__class__, BIGINT) + is_(User.__table__.c.other_id.type.__class__, BIGINT) + is_(User.__table__.c.data.type.__class__, Integer) + + @testing.combinations(True, False, argnames="include_rhs_type") + def test_construct_nullability_overrides( + self, decl_base, include_rhs_type + ): + + if include_rhs_type: + args = (String,) + else: + args = () + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + + lnnl_rndf: Mapped[str] = mapped_column(*args) + lnnl_rnnl: Mapped[str] = mapped_column(*args, nullable=False) + lnnl_rnl: Mapped[str] = mapped_column(*args, nullable=True) + lnl_rndf: Mapped[Optional[str]] = mapped_column(*args) + lnl_rnnl: Mapped[Optional[str]] = mapped_column( + *args, nullable=False + ) + lnl_rnl: Mapped[Optional[str]] = mapped_column( + *args, nullable=True + ) + + is_false(User.__table__.c.lnnl_rndf.nullable) + is_false(User.__table__.c.lnnl_rnnl.nullable) + is_true(User.__table__.c.lnnl_rnl.nullable) + + is_true(User.__table__.c.lnl_rndf.nullable) + is_false(User.__table__.c.lnl_rnnl.nullable) + is_true(User.__table__.c.lnl_rnl.nullable) + + def test_fwd_refs(self, decl_base: Type[DeclarativeBase]): + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped["int"] = mapped_column(primary_key=True) + data_one: Mapped["str"] + + def test_annotated_types_as_keys(self, decl_base: Type[DeclarativeBase]): + """neat!!!""" + + str50 = Annotated[str, 50] + str30 = Annotated[str, 30] + opt_str50 = Optional[str50] + opt_str30 = Optional[str30] + + decl_base.registry.update_type_annotation_map( + {str50: String(50), str30: String(30)} + ) + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[str50] = mapped_column(primary_key=True) + data_one: Mapped[str30] + data_two: Mapped[opt_str30] + data_three: Mapped[str50] + data_four: Mapped[opt_str50] + data_five: Mapped[str] + data_six: Mapped[Optional[str]] + + eq_(MyClass.__table__.c.data_one.type.length, 30) + is_false(MyClass.__table__.c.data_one.nullable) + eq_(MyClass.__table__.c.data_two.type.length, 30) + is_true(MyClass.__table__.c.data_two.nullable) + eq_(MyClass.__table__.c.data_three.type.length, 50) + + def test_unions(self): + our_type = Numeric(10, 2) + + class Base(DeclarativeBase): + type_annotation_map = {Union[float, Decimal]: our_type} + + class User(Base): + __tablename__ = "users" + __table__: Table + + id: Mapped[int] = mapped_column(primary_key=True) + + data: Mapped[Union[float, Decimal]] = mapped_column() + reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + optional_data: Mapped[ + Optional[Union[float, Decimal]] + ] = mapped_column() + + # use Optional directly + reverse_optional_data: Mapped[ + Optional[Union[Decimal, float]] + ] = mapped_column() + + # use Union with None, same as Optional but presents differently + # (Optional object with __origin__ Union vs. Union) + reverse_u_optional_data: Mapped[ + Union[Decimal, float, None] + ] = mapped_column() + float_data: Mapped[float] = mapped_column() + decimal_data: Mapped[Decimal] = mapped_column() + + is_(User.__table__.c.data.type, our_type) + is_false(User.__table__.c.data.nullable) + is_(User.__table__.c.reverse_data.type, our_type) + is_(User.__table__.c.optional_data.type, our_type) + is_true(User.__table__.c.optional_data.nullable) + + is_(User.__table__.c.reverse_optional_data.type, our_type) + is_(User.__table__.c.reverse_u_optional_data.type, our_type) + is_true(User.__table__.c.reverse_optional_data.nullable) + is_true(User.__table__.c.reverse_u_optional_data.nullable) + + is_(User.__table__.c.float_data.type, our_type) + is_(User.__table__.c.decimal_data.type, our_type) + + def test_missing_mapped_lhs(self, decl_base): + with expect_raises_message( + ArgumentError, + r'Type annotation for "User.name" should use the ' + r'syntax "Mapped\[str\]" or "MappedColumn\[str\]"', + ): + + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: str = mapped_column() # type: ignore + + def test_construct_lhs_separate_name(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + data: Mapped[Optional[str]] = mapped_column("the_data") + + self.assert_compile( + select(User.data), "SELECT users.the_data FROM users" + ) + is_true(User.__table__.c.the_data.nullable) + + def test_construct_works_in_expr(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + + class Address(decl_base): + __tablename__ = "addresses" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) + + user = relationship(User, primaryjoin=user_id == User.id) + + self.assert_compile( + select(Address.user_id, User.id).join(Address.user), + "SELECT addresses.user_id, users.id FROM addresses " + "JOIN users ON addresses.user_id = users.id", + ) + + def test_construct_works_as_polymorphic_on(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + decl_base.registry.configure() + is_(User.__table__.c.type, User.__mapper__.polymorphic_on) + + def test_construct_works_as_version_id_col(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + version_id: Mapped[int] = mapped_column() + + __mapper_args__ = {"version_id_col": version_id} + + decl_base.registry.configure() + is_(User.__table__.c.version_id, User.__mapper__.version_id_col) + + def test_construct_works_in_deferred(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = deferred(mapped_column()) + + self.assert_compile(select(User), "SELECT users.id FROM users") + self.assert_compile( + select(User).options(undefer(User.data)), + "SELECT users.data, users.id FROM users", + ) + + def test_deferred_kw(self, decl_base): + class User(decl_base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(deferred=True) + + self.assert_compile(select(User), "SELECT users.id FROM users") + self.assert_compile( + select(User).options(undefer(User.data)), + "SELECT users.data, users.id FROM users", + ) + + +class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_mapped_column_omit_fn(self, decl_base): + class MixinOne: + name: Mapped[str] + x: Mapped[int] + y: Mapped[int] = mapped_column() + + class A(MixinOne, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + # ordering of cols is TODO + eq_(A.__table__.c.keys(), ["id", "y", "name", "x"]) + + def test_mc_duplication_plain(self, decl_base): + class MixinOne: + name: Mapped[str] = mapped_column() + + class A(MixinOne, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + class B(MixinOne, decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + is_not(A.__table__.c.name, B.__table__.c.name) + + def test_mc_duplication_declared_attr(self, decl_base): + class MixinOne: + @declared_attr + def name(cls) -> Mapped[str]: + return mapped_column() + + class A(MixinOne, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + class B(MixinOne, decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + is_not(A.__table__.c.name, B.__table__.c.name) + + def test_relationship_requires_declared_attr(self, decl_base): + class Related(decl_base): + __tablename__ = "related" + + id: Mapped[int] = mapped_column(primary_key=True) + + class HasRelated: + related_id: Mapped[int] = mapped_column(ForeignKey(Related.id)) + + related: Mapped[Related] = relationship() + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Mapper properties \(i.e. deferred,column_property\(\), " + r"relationship\(\), etc.\) must be declared", + ): + + class A(HasRelated, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + def test_relationship_duplication_declared_attr(self, decl_base): + class Related(decl_base): + __tablename__ = "related" + + id: Mapped[int] = mapped_column(primary_key=True) + + class HasRelated: + related_id: Mapped[int] = mapped_column(ForeignKey(Related.id)) + + @declared_attr + def related(cls) -> Mapped[Related]: + return relationship() + + class A(HasRelated, decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + class B(HasRelated, decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + self.assert_compile( + select(A).join(A.related), + "SELECT a.id, a.related_id FROM a " + "JOIN related ON related.id = a.related_id", + ) + self.assert_compile( + select(B).join(B.related), + "SELECT b.id, b.related_id FROM b " + "JOIN related ON related.id = b.related_id", + ) + + +class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def decl_base(self): + class Base(DeclarativeBase): + pass + + yield Base + Base.registry.dispose() + + def test_no_typing_in_rhs(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + bs = relationship("List['B']") + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"When initializing mapper Mapper\[A\(a\)\], expression " + r'"relationship\(\"List\[\'B\'\]\"\)\" seems to be using a ' + r"generic class as the argument to relationship\(\); please " + r"state the generic argument using an annotation, e.g. " + r'"bs: Mapped\[List\[\'B\'\]\] = relationship\(\)"', + ): + + decl_base.registry.configure() + + def test_required_no_arg(self, decl_base): + with expect_raises_message( + sa_exc.ArgumentError, + r"Python typing annotation is required for attribute " + r'"A.bs" when primary ' + r'argument\(s\) for "Relationship" construct are None or ' + r"not present", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + bs = relationship() + + def test_rudimentary_dataclasses_support(self, registry): + @registry.mapped + @dataclasses.dataclass + class A: + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + id: Mapped[int] = mapped_column(primary_key=True) + bs: List["B"] = dataclasses.field( # noqa: F821 + default_factory=list, metadata={"sa": relationship()} + ) + + @registry.mapped + @dataclasses.dataclass + class B: + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id = mapped_column(ForeignKey("a.id")) + + self.assert_compile( + select(A).join(A.bs), "SELECT a.id FROM a JOIN b ON a.id = b.a_id" + ) + + def test_basic_bidirectional(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs: Mapped[List["B"]] = relationship( # noqa F821 + back_populates="a" + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped["A"] = relationship( + back_populates="bs", primaryjoin=a_id == A.id + ) + + a1 = A(data="data") + b1 = B() + a1.bs.append(b1) + is_(a1, b1.a) + + def test_wrong_annotation_type_one(self, decl_base): + + with expect_raises_message( + sa_exc.ArgumentError, + r"Type annotation for \"A.data\" should use the " + r"syntax \"Mapped\['B'\]\" or \"Relationship\['B'\]\"", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: "B" = relationship() # type: ignore # noqa + + def test_wrong_annotation_type_two(self, decl_base): + + with expect_raises_message( + sa_exc.ArgumentError, + r"Type annotation for \"A.data\" should use the " + r"syntax \"Mapped\[B\]\" or \"Relationship\[B\]\"", + ): + + class B(decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True) + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: B = relationship() # type: ignore # noqa + + def test_wrong_annotation_type_three(self, decl_base): + + with expect_raises_message( + sa_exc.ArgumentError, + r"Type annotation for \"A.data\" should use the " + r"syntax \"Mapped\['List\[B\]'\]\" or " + r"\"Relationship\['List\[B\]'\]\"", + ): + + class B(decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True) + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: "List[B]" = relationship() # type: ignore # noqa + + def test_collection_class_uselist(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs_list: Mapped[List["B"]] = relationship( # noqa F821 + viewonly=True + ) + bs_set: Mapped[Set["B"]] = relationship(viewonly=True) # noqa F821 + bs_list_warg: Mapped[List["B"]] = relationship( # noqa F821 + "B", viewonly=True + ) + bs_set_warg: Mapped[Set["B"]] = relationship( # noqa F821 + "B", viewonly=True + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped["A"] = relationship(viewonly=True) + a_warg: Mapped["A"] = relationship("A", viewonly=True) + + is_(A.__mapper__.attrs["bs_list"].collection_class, list) + is_(A.__mapper__.attrs["bs_set"].collection_class, set) + is_(A.__mapper__.attrs["bs_list_warg"].collection_class, list) + is_(A.__mapper__.attrs["bs_set_warg"].collection_class, set) + is_true(A.__mapper__.attrs["bs_list"].uselist) + is_true(A.__mapper__.attrs["bs_set"].uselist) + is_true(A.__mapper__.attrs["bs_list_warg"].uselist) + is_true(A.__mapper__.attrs["bs_set_warg"].uselist) + + is_false(B.__mapper__.attrs["a"].uselist) + is_false(B.__mapper__.attrs["a_warg"].uselist) + + def test_collection_class_dict_no_collection(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs: Mapped[Dict[str, "B"]] = relationship() # noqa F821 + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + # this is the old collections message. it's not great, but at the + # moment I like that this is what's raised + with expect_raises_message( + sa_exc.ArgumentError, + "Type InstrumentedDict must elect an appender", + ): + decl_base.registry.configure() + + def test_collection_class_dict_attr_mapped_collection(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + bs: Mapped[MappedCollection[str, "B"]] = relationship( # noqa F821 + collection_class=attribute_mapped_collection("name") + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + decl_base.registry.configure() + + a1 = A() + b1 = B(name="foo") + + # collection appender on MappedCollection + a1.bs.set(b1) + + is_(a1.bs["foo"], b1) + + +class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def dataclass_point_fixture(self, decl_base): + @dataclasses.dataclass + class Point: + x: int + y: int + + class Edge(decl_base): + __tablename__ = "edge" + id: Mapped[int] = mapped_column(primary_key=True) + graph_id: Mapped[int] = mapped_column(ForeignKey("graph.id")) + + start: Mapped[Point] = composite( + Point, mapped_column("x1"), mapped_column("y1") + ) + + end: Mapped[Point] = composite( + Point, mapped_column("x2"), mapped_column("y2") + ) + + class Graph(decl_base): + __tablename__ = "graph" + id: Mapped[int] = mapped_column(primary_key=True) + + edges: Mapped[List[Edge]] = relationship() + + decl_base.metadata.create_all(testing.db) + return Point, Graph, Edge + + def test_composite_setup(self, dataclass_point_fixture): + Point, Graph, Edge = dataclass_point_fixture + + with fixture_session() as sess: + sess.add( + Graph( + edges=[ + Edge(start=Point(1, 2), end=Point(3, 4)), + Edge(start=Point(7, 8), end=Point(5, 6)), + ] + ) + ) + sess.commit() + + self.assert_compile( + select(Edge), + "SELECT edge.id, edge.graph_id, edge.x1, edge.y1, " + "edge.x2, edge.y2 FROM edge", + ) + + with fixture_session() as sess: + g1 = sess.scalar(select(Graph)) + + # round trip! + eq_(g1.edges[0].end, Point(3, 4)) + + def test_named_setup(self, decl_base: Type[DeclarativeBase]): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + Address, mapped_column(), mapped_column(), mapped_column("zip") + ) + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + + def test_no_fwd_ref_annotated_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + with expect_raises_message( + ArgumentError, + r"Can't use forward ref ForwardRef\('Address'\) " + r"for composite class argument", + ): + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped["Address"] = composite( + mapped_column(), mapped_column(), mapped_column("zip") + ) + + def test_fwd_ref_plus_no_mapped(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + with expect_raises_message( + ArgumentError, + r"Type annotation for \"User.address\" should use the syntax " + r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"", + ): + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: "Address" = composite( # type: ignore + mapped_column(), mapped_column(), mapped_column("zip") + ) + + def test_fwd_ref_ok_explicit_cls(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped["Address"] = composite( + Address, mapped_column(), mapped_column(), mapped_column("zip") + ) + + self.assert_compile( + select(User), + 'SELECT "user".id, "user".name, "user".street, ' + '"user".state, "user".zip FROM "user"', + ) + + def test_cls_annotated_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + mapped_column(), mapped_column(), mapped_column("zip") + ) + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street", "NY", "12345")) + + def test_one_col_setup(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite(Address, mapped_column()) + + decl_base.metadata.create_all(testing.db) + + with fixture_session() as sess: + sess.add( + User( + name="user 1", + address=Address("123 anywhere street"), + ) + ) + sess.commit() + + with fixture_session() as sess: + u1 = sess.scalar(select(User)) + + # round trip! + eq_(u1.address, Address("123 anywhere street")) + + +class AllYourFavoriteHitsTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """try a bunch of common mappings using the new style""" + + __dialect__ = "default" + + def test_employee_joined_inh(self, decl_base: Type[DeclarativeBase]): + + str50 = Annotated[str, 50] + str30 = Annotated[str, 30] + opt_str50 = Optional[str50] + + decl_base.registry.update_type_annotation_map( + {str50: String(50), str30: String(30)} + ) + + class Company(decl_base): + __tablename__ = "company" + + company_id: Mapped[int] = mapped_column(Integer, primary_key=True) + + name: Mapped[str50] + + employees: Mapped[Set["Person"]] = relationship() # noqa F821 + + class Person(decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column(primary_key=True) + company_id: Mapped[int] = mapped_column( + ForeignKey("company.company_id") + ) + name: Mapped[str50] + type: Mapped[str30] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[opt_str50] + primary_language: Mapped[opt_str50] + + class Manager(Person): + __tablename__ = "manager" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True + ) + status: Mapped[str] = mapped_column(String(30)) + manager_name: Mapped[str50] + + is_(Person.__mapper__.polymorphic_on, Person.__table__.c.type) + + # the SELECT statements here confirm the columns present and their + # ordering + self.assert_compile( + select(Person), + "SELECT person.person_id, person.company_id, person.name, " + "person.type FROM person", + ) + + self.assert_compile( + select(Manager), + "SELECT manager.person_id, person.person_id AS person_id_1, " + "person.company_id, person.name, person.type, manager.status, " + "manager.manager_name FROM person " + "JOIN manager ON person.person_id = manager.person_id", + ) + + self.assert_compile( + select(Company).join(Company.employees.of_type(Engineer)), + "SELECT company.company_id, company.name FROM company JOIN " + "(person JOIN engineer ON person.person_id = engineer.person_id) " + "ON company.company_id = person.company_id", + ) diff --git a/test/orm/declarative/test_typing_py3k.py b/test/orm/declarative/test_typing_py3k.py deleted file mode 100644 index 0be91a509f..0000000000 --- a/test/orm/declarative/test_typing_py3k.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Generic -from typing import Type -from typing import TypeVar - -from sqlalchemy import Column -from sqlalchemy import inspect -from sqlalchemy import Integer -from sqlalchemy.orm import as_declarative -from sqlalchemy.testing import eq_ -from sqlalchemy.testing import fixtures -from sqlalchemy.testing import is_ -from sqlalchemy.testing.assertions import expect_raises - - -class DeclarativeBaseTest(fixtures.TestBase): - def test_class_getitem(self): - T = TypeVar("T", bound="CommonBase") # noqa - - class CommonBase(Generic[T]): - @classmethod - def boring(cls: Type[T]) -> Type[T]: - return cls - - @classmethod - def more_boring(cls: Type[T]) -> int: - return 27 - - @as_declarative() - class Base(CommonBase[T]): - foo = 1 - - class Tab(Base["Tab"]): - __tablename__ = "foo" - a = Column(Integer, primary_key=True) - - eq_(Tab.foo, 1) - is_(Tab.__table__, inspect(Tab).local_table) - eq_(Tab.boring(), Tab) - eq_(Tab.more_boring(), 27) - - with expect_raises(AttributeError): - Tab.non_existent diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index fae146755a..c5031ed596 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import joinedload from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship -from sqlalchemy.orm.util import with_polymorphic +from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py index 5f8ff56395..d9d4a9a221 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -339,13 +339,13 @@ class TestGeometries(GeometryFixtureBase): testing.db, q.all, CompiledSQL( - "SELECT a.type AS a_type, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, " "a.a_data AS a_a_data FROM a", {}, ), Or( CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " "c.c_data AS c_c_data, c.e_data AS c_e_data, " "c.d_data AS c_d_data " "FROM a JOIN c ON a.id = c.id " @@ -354,7 +354,7 @@ class TestGeometries(GeometryFixtureBase): [{"primary_keys": [1, 2]}], ), CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " "c.c_data AS c_c_data, " "c.d_data AS c_d_data, c.e_data AS c_e_data " "FROM a JOIN c ON a.id = c.id " @@ -396,13 +396,13 @@ class TestGeometries(GeometryFixtureBase): testing.db, q.all, CompiledSQL( - "SELECT a.type AS a_type, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, " "a.a_data AS a_a_data FROM a", {}, ), Or( CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, c.id AS c_id, " "c.c_data AS c_c_data, c.e_data AS c_e_data, " "c.d_data AS c_d_data " "FROM a JOIN c ON a.id = c.id " @@ -411,7 +411,7 @@ class TestGeometries(GeometryFixtureBase): [{"primary_keys": [1, 2]}], ), CompiledSQL( - "SELECT a.type AS a_type, c.id AS c_id, a.id AS a_id, " + "SELECT c.id AS c_id, a.id AS a_id, a.type AS a_type, " "c.c_data AS c_c_data, c.d_data AS c_d_data, " "c.e_data AS c_e_data " "FROM a JOIN c ON a.id = c.id " @@ -465,15 +465,15 @@ class TestGeometries(GeometryFixtureBase): testing.db, q.all, CompiledSQL( - "SELECT a.type AS a_type, a.id AS a_id, " + "SELECT a.id AS a_id, a.type AS a_type, " "a.a_data AS a_a_data FROM a ORDER BY a.id", {}, ), Or( # here, the test is that the adaptation of "a" takes place CompiledSQL( - "SELECT poly.a_type AS poly_a_type, " - "poly.c_id AS poly_c_id, " + "SELECT poly.c_id AS poly_c_id, " + "poly.a_type AS poly_a_type, " "poly.a_id AS poly_a_id, poly.c_c_data AS poly_c_c_data, " "poly.e_id AS poly_e_id, poly.e_e_data AS poly_e_e_data, " "poly.d_id AS poly_d_id, poly.d_d_data AS poly_d_d_data " @@ -489,9 +489,9 @@ class TestGeometries(GeometryFixtureBase): [{"primary_keys": [1, 2]}], ), CompiledSQL( - "SELECT poly.a_type AS poly_a_type, " - "poly.c_id AS poly_c_id, " - "poly.a_id AS poly_a_id, poly.c_c_data AS poly_c_c_data, " + "SELECT poly.c_id AS poly_c_id, " + "poly.a_id AS poly_a_id, poly.a_type AS poly_a_type, " + "poly.c_c_data AS poly_c_c_data, " "poly.d_id AS poly_d_id, poly.d_d_data AS poly_d_d_data, " "poly.e_id AS poly_e_id, poly.e_e_data AS poly_e_e_data " "FROM (SELECT a.id AS a_id, a.type AS a_type, " diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 19e090e0ed..f41947b6c8 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -6,8 +6,8 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import aliased +from sqlalchemy.orm import Composite from sqlalchemy.orm import composite -from sqlalchemy.orm import CompositeProperty from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -1105,7 +1105,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): if custom: - class CustomComparator(sa.orm.CompositeProperty.Comparator): + class CustomComparator(sa.orm.Composite.Comparator): def near(self, other, d): clauses = self.__clause_element__().clauses diff_x = clauses[0] - other.x @@ -1163,7 +1163,7 @@ class ComparatorTest(fixtures.MappedTest, testing.AssertsCompiledSQL): Edge = self.classes.Edge start_prop = Edge.start.property - assert start_prop.comparator_factory is CompositeProperty.Comparator + assert start_prop.comparator_factory is Composite.Comparator def test_custom_comparator_factory(self): self._fixture(True) diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py index 706024eb5e..13cd6dec5e 100644 --- a/test/orm/test_dataclasses_py3k.py +++ b/test/orm/test_dataclasses_py3k.py @@ -271,7 +271,9 @@ class PlainDeclarativeDataclassesTest(DataclassesTest): widgets: List[Widget] = dataclasses.field(default_factory=list) widget_count: int = dataclasses.field(init=False) - widgets = relationship("Widget") + __mapper_args__ = dict( + properties=dict(widgets=relationship("Widget")) + ) def __post_init__(self): self.widget_count = len(self.widgets) @@ -912,7 +914,7 @@ class PropagationFromMixinTest(fixtures.TestBase): eq_(BaseType.__table__.name, "basetype") eq_( list(BaseType.__table__.c.keys()), - ["timestamp", "type", "id", "value"], + ["type", "id", "value", "timestamp"], ) eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"}) assert Single.__table__ is BaseType.__table__ diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index de211cf63b..1fad974b92 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -25,6 +25,7 @@ from sqlalchemy.orm import Load from sqlalchemy.orm import load_only from sqlalchemy.orm import reconstructor from sqlalchemy.orm import registry +from sqlalchemy.orm import Relationship from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym @@ -2896,12 +2897,10 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.classes.User, ) - from sqlalchemy.orm.relationships import RelationshipProperty - # NOTE: this API changed in 0.8, previously __clause_element__() # gave the parent selecatable, now it gives the # primaryjoin/secondaryjoin - class MyFactory(RelationshipProperty.Comparator): + class MyFactory(Relationship.Comparator): __hash__ = None def __eq__(self, other): @@ -2909,7 +2908,7 @@ class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): self._source_selectable().c.user_id ) == func.foobar(other.id) - class MyFactory2(RelationshipProperty.Comparator): + class MyFactory2(Relationship.Comparator): __hash__ = None def __eq__(self, other): diff --git a/test/orm/test_options.py b/test/orm/test_options.py index e74ffeced4..96759e3889 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -930,7 +930,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): [Item], lambda: (load_only(Item.keywords),), 'Can\'t apply "column loader" strategy to property ' - '"Item.keywords", which is a "relationship property"; this ' + '"Item.keywords", which is a "relationship"; this ' 'loader strategy is intended to be used with a "column property".', ) @@ -942,7 +942,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): lambda: (joinedload(Keyword.id).joinedload(Item.keywords),), 'Can\'t apply "joined loader" strategy to property "Keyword.id", ' 'which is a "column property"; this loader strategy is intended ' - 'to be used with a "relationship property".', + 'to be used with a "relationship".', ) def test_option_against_wrong_multi_entity_type_attr_two(self): @@ -953,7 +953,7 @@ class OptionsNoPropTest(_fixtures.FixtureTest): lambda: (joinedload(Keyword.keywords).joinedload(Item.keywords),), 'Can\'t apply "joined loader" strategy to property ' '"Keyword.keywords", which is a "column property"; this loader ' - 'strategy is intended to be used with a "relationship property".', + 'strategy is intended to be used with a "relationship".', ) def test_option_against_wrong_multi_entity_type_attr_three(self): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index e7fdf661a6..d0c8f41084 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -50,6 +50,7 @@ from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager from sqlalchemy.orm import defer from sqlalchemy.orm import deferred +from sqlalchemy.orm import join from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload from sqlalchemy.orm import Query @@ -59,9 +60,8 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload from sqlalchemy.orm import synonym from sqlalchemy.orm import undefer +from sqlalchemy.orm import with_parent from sqlalchemy.orm.context import QueryContext -from sqlalchemy.orm.util import join -from sqlalchemy.orm.util import with_parent from sqlalchemy.sql import expression from sqlalchemy.sql import operators from sqlalchemy.testing import AssertsCompiledSQL diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index e02f0e2ed5..0924941658 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -2528,7 +2528,6 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): assert t2.c.x.references(t1.c.x) def test_create_drop_schema(self): - self.assert_compile( schema.CreateSchema("sa_schema"), "CREATE SCHEMA sa_schema" ) diff --git a/tox.ini b/tox.ini index b2a7a154d9..71fef2a834 100644 --- a/tox.ini +++ b/tox.ini @@ -146,7 +146,6 @@ deps= importlib_metadata; python_version < '3.8' mypy patch==1.* - git+https://github.com/sqlalchemy/sqlalchemy2-stubs commands = pytest -m mypy {posargs}