]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement Mypy plugin
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Feb 2021 23:36:50 +0000 (18:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Mar 2021 00:01:41 +0000 (19:01 -0500)
Rudimentary and experimental support for Mypy has been added in the form of
a new plugin, which itself depends on new typing stubs for SQLAlchemy. The
plugin allows declarative mappings in their standard form to both be
compatible with Mypy as well as to provide typing support for mapped
classes and instances.

Fixes: #4609
Change-Id: Ia035978c02ad3a5c0e5b3c6c30044dd5a3155170

53 files changed:
doc/build/changelog/unreleased_14/4609.rst [new file with mode: 0644]
doc/build/conf.py
doc/build/glossary.rst
doc/build/index.rst
doc/build/orm/extensions/index.rst
doc/build/orm/extensions/mypy.rst [new file with mode: 0644]
doc/build/orm/internals.rst
doc/build/orm/mapping_styles.rst
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/ext/mypy/__init__.py [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/decl_class.py [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/names.py [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/plugin.py [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/util.py [new file with mode: 0644]
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/testing/requirements.py
setup.cfg
test/ext/mypy/files/abstract_one.py [new file with mode: 0644]
test/ext/mypy/files/cols_noninferred_plain_nonopt.py [new file with mode: 0644]
test/ext/mypy/files/cols_notype_on_fk_col.py [new file with mode: 0644]
test/ext/mypy/files/complete_orm_no_plugin.py [new file with mode: 0644]
test/ext/mypy/files/composite_props.py [new file with mode: 0644]
test/ext/mypy/files/constr_cols_only.py [new file with mode: 0644]
test/ext/mypy/files/dataclasses_workaround.py [new file with mode: 0644]
test/ext/mypy/files/decl_attrs_one.py [new file with mode: 0644]
test/ext/mypy/files/decl_attrs_two.py [new file with mode: 0644]
test/ext/mypy/files/decl_base_subclass_one.py [new file with mode: 0644]
test/ext/mypy/files/decl_base_subclass_two.py [new file with mode: 0644]
test/ext/mypy/files/declarative_base_dynamic.py [new file with mode: 0644]
test/ext/mypy/files/declarative_base_explicit.py [new file with mode: 0644]
test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py [new file with mode: 0644]
test/ext/mypy/files/ensure_descriptor_type_noninferred.py [new file with mode: 0644]
test/ext/mypy/files/ensure_descriptor_type_semiinferred.py [new file with mode: 0644]
test/ext/mypy/files/imperative_table.py [new file with mode: 0644]
test/ext/mypy/files/inspect.py [new file with mode: 0644]
test/ext/mypy/files/invalid_noninferred_lh_type.py [new file with mode: 0644]
test/ext/mypy/files/mapped_attr_assign.py [new file with mode: 0644]
test/ext/mypy/files/mixin_one.py [new file with mode: 0644]
test/ext/mypy/files/mixin_two.py [new file with mode: 0644]
test/ext/mypy/files/other_mapper_props.py [new file with mode: 0644]
test/ext/mypy/files/plugin_doesnt_break_one.py [new file with mode: 0644]
test/ext/mypy/files/relationship_direct_cls.py [new file with mode: 0644]
test/ext/mypy/files/relationship_err1.py [new file with mode: 0644]
test/ext/mypy/files/relationship_err2.py [new file with mode: 0644]
test/ext/mypy/files/relationship_err3.py [new file with mode: 0644]
test/ext/mypy/files/typeless_fk_col_cant_infer.py [new file with mode: 0644]
test/ext/mypy/files/typing_err1.py [new file with mode: 0644]
test/ext/mypy/files/typing_err2.py [new file with mode: 0644]
test/ext/mypy/files/typing_err3.py [new file with mode: 0644]
test/ext/mypy/test_mypy_plugin_py3k.py [new file with mode: 0644]
tox.ini

diff --git a/doc/build/changelog/unreleased_14/4609.rst b/doc/build/changelog/unreleased_14/4609.rst
new file mode 100644 (file)
index 0000000..2c7ed15
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: feature, mypy
+    :tickets: 4609
+
+    Rudimentary and experimental support for Mypy has been added in the form of
+    a new plugin, which itself depends on new typing stubs for SQLAlchemy. The
+    plugin allows declarative mappings in their standard form to both be
+    compatible with Mypy as well as to provide typing support for mapped
+    classes and instances.
+
+    .. seealso::
+
+        :ref:`mypy_toplevel`
index 505647708ad5a75385e1ac848f8e2e17c1e969d0..f1a1e2faf11ad1af7196b7eb3974560ef24a2f84 100644 (file)
@@ -62,6 +62,7 @@ changelog_sections = [
     "sql",
     "schema",
     "extensions",
+    "mypy",
     "asyncio",
     "postgresql",
     "mysql",
@@ -120,6 +121,7 @@ autodocmods_convert_modname = {
     "sqlalchemy.ext.asyncio.engine": "sqlalchemy.ext.asyncio",
     "sqlalchemy.ext.asyncio.session": "sqlalchemy.ext.asyncio",
     "sqlalchemy.util._collections": "sqlalchemy.util",
+    "sqlalchemy.orm.attributes": "sqlalchemy.orm",
     "sqlalchemy.orm.relationships": "sqlalchemy.orm",
     "sqlalchemy.orm.interfaces": "sqlalchemy.orm",
     "sqlalchemy.orm.query": "sqlalchemy.orm",
index 2939fdf34c6acccf7694a212f32f4b6aab443532..2a4fd5f585380b685128869061b727b49fec377d 100644 (file)
@@ -224,9 +224,13 @@ Glossary
 
     descriptor
     descriptors
-        In Python, a descriptor is an object attribute with “binding behavior”, one whose attribute access has been overridden by methods in the `descriptor protocol <http://docs.python.org/howto/descriptor.html>`_.
-        Those methods are __get__(), __set__(), and __delete__(). If any of those methods are defined
-        for an object, it is said to be a descriptor.
+
+        In Python, a descriptor is an object attribute with “binding behavior”,
+        one whose attribute access has been overridden by methods in the
+        `descriptor protocol <http://docs.python.org/howto/descriptor.html>`_.
+        Those methods are ``__get__()``, ``__set__()``, and ``__delete__()``.
+        If any of those methods are defined for an object, it is said to be a
+        descriptor.
 
         In SQLAlchemy, descriptors are used heavily in order to provide attribute behavior
         on mapped classes.   When a class is mapped as such::
index e7f19c1f1953cffa148ea1130cbf5c8d9c34c20c..39a4cfb2c5c2e25f2336c3d1ff15b518c6d6324f 100644 (file)
@@ -83,6 +83,7 @@ SQLAlchemy Documentation
       :doc:`AsyncIO Support <orm/extensions/asyncio>`
 
     * **Configuration Extensions:**
+      :doc:`Mypy integration <orm/extensions/mypy>` |
       :doc:`Association Proxy <orm/extensions/associationproxy>` |
       :doc:`Hybrid Attributes <orm/extensions/hybrid>` |
       :doc:`Automap <orm/extensions/automap>` |
index ba040b9f65f84d608b03080ff9e7379477a83814..0dda58affa6ad7452e69536d957e56c605b07e3e 100644 (file)
@@ -20,6 +20,7 @@ behavior.   In particular the "Horizontal Sharding", "Hybrid Attributes", and
     automap
     baked
     declarative/index
+    mypy
     mutable
     orderinglist
     horizontal_shard
diff --git a/doc/build/orm/extensions/mypy.rst b/doc/build/orm/extensions/mypy.rst
new file mode 100644 (file)
index 0000000..e8d85c1
--- /dev/null
@@ -0,0 +1,542 @@
+.. _mypy_toplevel:
+
+Mypy  / Pep-484 Support for ORM Mappings
+========================================
+
+Support for :pep:`484` typing annotations as well as the
+`Mypy <https://mypy.readthedocs.io/>`_ type checking tool.
+
+
+.. note:: The Mypy plugin and typing annotations should be regarded as
+   **alpha level** for the
+   early 1.4 releases of SQLAlchemy.  The plugin has not been tested in real world
+   scenarios and may have many unhandled cases and error conditions.
+   Specifics of the new typing stubs are also **subject to change** during
+   the 1.4 series.
+
+Installation
+------------
+
+The Mypy plugin depends upon new stubs for SQLAlchemy packaged at
+`sqlalchemy2-stubs <https://pypi.org/project/sqlalchemy2-stubs/>`_.  These
+stubs necessarily fully replace the previous ``sqlalchemy-stubs`` typing
+annotations published by Dropbox, as they occupy the same ``sqlalchemy-stubs``
+namespace as specified by :pep:`561`.  The `Mypy <https://pypi.org/project/mypy/>`_
+package itself is also a dependency.
+
+Both packages may be installed using the "mypy" extras hook using pip::
+
+    pip install sqlalchemy[mypy]
+
+The plugin itself is configured as described in
+`Configuring mypy to use Plugins <https://mypy.readthedocs.io/en/latest/extending_mypy.html#configuring-mypy-to-use-plugins>`_,
+using the ``sqlalchemy.ext.mypy.plugin`` module name, such as within
+``setup.cfg``::
+
+    [mypy]
+    plugins = sqlalchemy.ext.mypy.plugin
+
+What the Plugin Does
+--------------------
+
+The primary purpose of the Mypy plugin is to intercept and alter the static
+definition of SQLAlchemy
+:ref:`declarative mappings <orm_declarative_mapper_config_toplevel>` so that
+they match up to how they are structured after they have been
+:term:`instrumented` by their :class:`_orm.Mapper` objects. This allows both
+the class structure itself as well as code that uses the class to make sense to
+the Mypy tool, which otherwise would not be the case based on how declarative
+mappings currently function.    The plugin is not unlike similar plugins
+that are required for libraries like
+`dataclasses <https://docs.python.org/3/library/dataclasses.html>`_ which
+alter classes dynamically at runtime.
+
+To cover the major areas where this occurs, consider the following ORM
+mapping, using the typical example of the ``User`` class::
+
+    from sqlalchemy import Column
+    from sqlalchemy import Integer
+    from sqlalchemy import String
+    from sqlalchemy import select
+    from sqlalchemy.orm import declarative_base
+
+    # "Base" is a class that is created dynamically from the
+    # declarative_base() function
+    Base = declarative_base()
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+    # "some_user" is an instance of the User class, which
+    # accepts "id" and "name" kwargs based on the mapping
+    some_user = User(id=5, name='user')
+
+    # it has an attribute called .name that's a string
+    print(f"Username: {some_user.name}")
+
+    # a select() construct makes use of SQL expressions derived from the
+    # User class itself
+    select_stmt = select(User).where(User.id.in_([3, 4, 5])).where(User.name.contains('s'))
+
+Above, the steps that the Mypy extension can take include:
+
+* Interpretation of the ``Base`` dynamic class generated by
+  :func:`_orm.declarative_base`, so that classes which inherit from it
+  are known to be mapped.  It also can accommodate the class decorator
+  approach described at :ref:`orm_declarative_decorator`.
+
+* Type inference for ORM mapped attributes that are defined in declarative
+  "inline" style, in the above example the ``id`` and ``name`` attributes of
+  the ``User`` class. This includes that an instance of ``User`` will use
+  ``int`` for ``id`` and ``str`` for ``name``. It also includes that when the
+  ``User.id`` and ``User.name`` class-level attributes are accessed, as they
+  are above in the ``select()`` statement, they are compatible with SQL
+  expression behavior, which is derived from the
+  :class:`_orm.InstrumentedAttribute` attribute descriptor class.
+
+* Application of an ``__init__()`` method to mapped classes that do not
+  already include an explicit constructor, which accepts keyword arguments
+  of specific types for all mapped attributes detected.
+
+When the Mypy plugin processes the above file, the resulting static class
+definition and Python code passed to the Mypy tool is equivalent to the
+following::
+
+    from sqlalchemy import Column
+    from sqlalchemy import Integer
+    from sqlalchemy import String
+    from sqlalchemy import select
+    from sqlalchemy.orm import declarative_base
+    from sqlalchemy.orm.decl_api import DeclarativeMeta
+    from sqlalchemy.orm import Mapped
+
+    class Base(metaclass=DeclarativeMeta):
+        __abstract__ = True
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id: Mapped[Optional[int]] = Mapped._special_method(
+            Column(Integer, primary_key=True)
+        )
+        name: Mapped[Optional[str]] = Mapped._special_method(
+            Column(String)
+        )
+
+        def __init__(self, id: Optional[int] = ..., name: Optional[str] = ...) -> None:
+            ...
+
+    some_user = User(id=5, name='user')
+
+    print(f"Username: {some_user.name}")
+
+    select_stmt = select(User).where(User.id.in_([3, 4, 5])).where(User.name.contains('s'))
+
+The key steps which have been taken above include:
+
+* The ``Base`` class is now defined in terms of the :class:`_orm.DeclarativeMeta`
+  class explicitly, rather than being a dynamic class.
+
+* The ``id`` and ``name`` attributes are defined in terms of the
+  :class:`_orm.Mapped` class, which represents a Python descriptor that
+  exhibits different behaviors at the class vs. instance levels.  The
+  :class:`_orm.Mapped` class is now the base class for the :class:`_orm.InstrumentedAttribute`
+  class that is used for all ORM mapped attributes.
+
+  In ``sqlalchemy2-stubs``,
+  :class:`_orm.Mapped` is defined as a generic class against arbitrary Python
+  types, meaning specific occurrences of :class:`_orm.Mapped` are associated
+  with a specific Python type, such as ``Mapped[Optional[int]]`` and
+  ``Mapped[Optional[str]]`` above.
+
+* The right-hand side of the declarative mapped attribute assignments are
+  **removed**, as this resembles the operation that the :class:`_orm.Mapper`
+  class would normally be doing, which is that it would be replacing these
+  attributes with specific instances of :class:`_orm.InstrumentedAttribute`.
+  The original expression is moved into a function call that will allow it to
+  still be type-checked without conflicting with the left-hand side of the
+  expression. For Mypy purposes, the left-hand typing annotation is sufficient
+  for the attribute's behavior to be understood.
+
+* A type stub for the ``User.__init__()`` method is added which includes the
+  correct keywords and datatypes.
+
+Usage
+------
+
+The following subsections will address individual uses cases that have
+so far been considered for pep-484 compliance.
+
+
+Introspection of Columns based on TypeEngine
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For mapped columns that include an explicit datatype, when they are mapped
+as inline attributes, the mapped type will be introspected automatically::
+
+    class MyClass(Base):
+        # ...
+
+        id = Column(Integer, primary_key=True)
+        name = Column("employee_name", String(50), nullable=False)
+        other_name = Column(String(50))
+
+Above, the ultimate class-level datatypes of ``id``, ``name`` and
+``other_name`` will be introspected as ``Mapped[Optional[int]]``,
+``Mapped[Optional[str]]`` and ``Mapped[Optional[str]]``. The types are by
+default **always** considered to be ``Optional``, even for the primary key and
+non-nullable column. The reason is because while the database columns "id" and
+"name" can't be NULL, the Python attributes ``id`` and ``name`` most certainly
+can be ``None`` without an explicit constructor::
+
+    >>> m1 = MyClass()
+    >>> m1.id
+    None
+
+The types of the above columns can be stated **explicitly**, providing the
+two advantages of clearer self-documentation as well as being able to
+control which types are optional::
+
+    class MyClass(Base):
+        # ...
+
+        id: int = Column(Integer, primary_key=True)
+        name: str = Column("employee_name", String(50), nullable=False)
+        other_name: Optional[str] = Column(String(50))
+
+The Mypy plugin will accept the above ``int``, ``str`` and ``Optional[str]``
+and convert them to include the ``Mapped[]`` type surrounding them.  The
+``Mapped[]`` construct may also be used explicitly::
+
+    from sqlalchemy.orm import Mapped
+
+    class MyClass(Base):
+        # ...
+
+        id: Mapped[int] = Column(Integer, primary_key=True)
+        name: Mapped[str] = Column("employee_name", String(50), nullable=False)
+        other_name: Mapped[Optional[str]] = Column(String(50))
+
+When the type is non-optional, it simply means that the attribute as accessed
+from an instance of ``MyClass`` will be considered to be non-None::
+
+    mc = MyClass(...)
+
+    # will pass mypy --strict
+    name: str = mc.name
+
+For optional attributes, Mypy considers that the type must include None
+or otherwise be ``Optional``::
+
+    mc = MyClass(...)
+
+    # will pass mypy --strict
+    other_name: Optional[str] = mc.name
+
+Whether or not the mapped attribute is typed as ``Optional``, the
+generation of the ``__init__()`` method will **still consider all keywords
+to be optional**.  This is again matching what the SQLAlchemy ORM actually
+does when it creates the constructor, and should not be confused with the
+behavior of a validating system such as Python ``dataclasses`` which will
+generate a constructor that matches the annotations in terms of optional
+vs. required attributes.
+
+.. tip::
+
+    In the above examples the :class:`_types.Integer` and
+    :class:`_types.String` datatypes are both :class:`_types.TypeEngine`
+    subclasses. In ``sqlalchemy2-stubs``, the :class:`_schema.Column` object is
+    a `generic <https://www.python.org/dev/peps/pep-0484/#generics>`_ which
+    subscribes to the type, e.g. above the column types are
+    ``Column[Integer]``, ``Column[String]``, and ``Column[String]``. The
+    :class:`_types.Integer` and :class:`_types.String` classes are in turn
+    generically subscribed to the Python types they correspond towards, i.e.
+    ``Integer(TypeEngine[int])``, ``String(TypeEngine[str])``.
+
+Columns that Don't have an Explicit Type
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Columns that include a :class:`_schema.ForeignKey` modifier do not need
+to specify a datatype in a SQLAlchemy declarative mapping.  For
+this type of attribute, the Mypy plugin will inform the user that it
+needs an explicit type to be sent::
+
+    # .. other imports
+    from sqlalchemy.sql.schema import ForeignKey
+
+    Base = declarative_base()
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+    class Address(Base):
+        __tablename__ = 'address'
+
+        id = Column(Integer, primary_key=True)
+        user_id = Column(ForeignKey("user.id"))
+
+The plugin will deliver the message as follows::
+
+    $ mypy test3.py --strict
+    test3.py:20: error: [SQLAlchemy Mypy plugin] Can't infer type from
+    ORM mapped expression assigned to attribute 'user_id'; please specify a
+    Python type or Mapped[<python type>] on the left hand side.
+    Found 1 error in 1 file (checked 1 source file)
+
+To resolve, apply an explicit type annotation to the ``Address.user_id``
+column::
+
+    class Address(Base):
+        __tablename__ = 'address'
+
+        id = Column(Integer, primary_key=True)
+        user_id: int = Column(ForeignKey("user.id"))
+
+Mapping Columns with Imperative Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In :ref:`imperative table style <orm_imperative_table_configuration>`, the
+:class:`_schema.Column` definitions are given inside of a :class:`_schema.Table`
+construct which is separate from the mapped attributes themselves.  The Mypy
+plugin does not consider this :class:`_schema.Table`, but instead supports that
+the attributes can be explicitly stated with a complete annotation that
+**must** use the :class:`_orm.Mapped` class to identify them as mapped attributes::
+
+    class MyClass(Base):
+        __table__ = Table(
+            "mytable",
+            Base.metadata,
+            Column(Integer, primary_key=True),
+            Column("employee_name", String(50), nullable=False),
+            Column(String(50))
+        )
+
+        id: Mapped[int]
+        name: Mapped[str]
+        other_name: Mapped[Optional[str]]
+
+The above :class:`_orm.Mapped` annotations are considered as mapped columns and
+will be included in the default constructor, as well as provide the correct
+typing profile for ``MyClass`` both at the class level and the instance level.
+
+Mapping Relationships
+^^^^^^^^^^^^^^^^^^^^^^
+
+The plugin has limited support for using type inference to detect the types
+for relationships.    For all those cases where it can't detect the type,
+it will emit an informative error message, and in all cases the appropriate
+type may be provided explicitly, either with the :class:`_orm.Mapped`
+class or optionally omitting it for an inline declaration.     The plugin
+also needs to determine whether or not the relationship refers to a collection
+or a scalar, and for that it relies upon the explicit value of
+the :paramref:`_orm.relationship.uselist` and/or :paramref:`_orm.relationship.collection_class`
+parameters.  An explicit type is needed if neither of these parameters are
+present, as well as if the target type of the :func:`_orm.relationship`
+is a string or callable, and not a class::
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+    class Address(Base):
+        __tablename__ = 'address'
+
+        id = Column(Integer, primary_key=True)
+        user_id: int = Column(ForeignKey("user.id"))
+
+        user = relationship(User)
+
+The above mapping will produce the following error::
+
+    test3.py:22: error: [SQLAlchemy Mypy plugin] Can't infer scalar or
+    collection for ORM mapped expression assigned to attribute 'user'
+    if both 'uselist' and 'collection_class' arguments are absent from the
+    relationship(); please specify a type annotation on the left hand side.
+    Found 1 error in 1 file (checked 1 source file)
+
+The error can be resolved either by using ``relationship(User, uselist=False)``
+or by providing the type, in this case the scalar ``User`` object::
+
+    class Address(Base):
+        __tablename__ = 'address'
+
+        id = Column(Integer, primary_key=True)
+        user_id: int = Column(ForeignKey("user.id"))
+
+        user: User = relationship(User)
+
+For collections, a similar pattern applies, where in the absence of
+``uselist=True`` or a :paramref:`_orm.relationship.collection_class`,
+a collection annotation such as ``List`` may be used.   It is also fully
+appropriate to use the string name of the class in the annotation as supported
+by pep-484, ensuring the class is imported with in
+the `TYPE_CHECKING block <https://www.python.org/dev/peps/pep-0484/#runtime-or-type-checking>`_
+as approriate::
+
+    from typing import List, TYPE_CHECKING
+    from .mymodel import Base
+
+    if TYPE_CHECKING:
+        # if the target of the relationship is in another module
+        # that cannot normally be imported at runtime
+        from .myaddressmodel import Address
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+        addresses: List["Address"] = relationship("Address")
+
+As is the case with columns, the :class:`_orm.Mapped` class may also be
+applied explicitly::
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        addresses: Mapped[List["Address"]] = relationship("Address", back_populates="user")
+
+    class Address(Base):
+        __tablename__ = 'address'
+
+        id = Column(Integer, primary_key=True)
+        user_id: int = Column(ForeignKey("user.id"))
+
+        user: Mapped[User] = relationship(User, back_populates="addresses")
+
+Using @declared_attr
+^^^^^^^^^^^^^^^^^^^^
+
+The :class:`_orm.declared_attr` class allows Declarative mapped attributes
+to be declared in class level functions, and is particularly useful when
+using `declarative mixins <orm_mixins_toplevel>`_.  For these functions,
+the return type of the function should be annotated using either the
+``Mapped[]`` construct or by indicating the exact kind of object returned
+by the function::
+
+    from sqlalchemy.orm.decl_api import declared_attr
+
+    class HasUpdatedAt:
+        @declared_attr
+        def updated_at(cls) -> Column[DateTime]:  # uses Column
+            return Column(DateTime)
+
+    class HasCompany:
+
+        @declared_attr
+        def company_id(cls) -> Mapped[int]:  # uses Mapped
+            return Column(ForeignKey("company.id"))
+
+        @declared_attr
+        def company(cls) -> Mapped["Company"]:
+            return relationship("Company")
+
+    class Employee(HasUpdatedAt, HasCompany, Base):
+        __tablename__ = 'employee'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+Note the mismatch between the actual return type of a method like
+``HasCompany.company`` vs. what is annotated.  The Mypy plugin converts
+all ``@declared_attr`` functions into simple annotated attributes to avoid
+this complexity::
+
+    # what Mypy sees
+    class HasCompany:
+        company_id: Mapped[int]
+        company: Mapped["Company"]
+
+
+Combining with Dataclasses or Other Type-Sensitive Attribute Systems
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The examples of Python dataclasses integration at :ref:`orm_declarative_dataclasses`
+presents a problem; Python dataclasses expect an explicit type that it will
+use to build the class, and the value given in each assignment statement
+is significant.    That is, a class as follows has to be stated exactly
+as it is in order to be accepted by dataclasses::
+
+    mapper_registry : registry = registry()
+
+
+    @mapper_registry.mapped
+    @dataclass
+    class User:
+        __table__ = Table(
+            "user",
+            mapper_registry.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column("fullname", String(50)),
+            Column("nickname", String(12)),
+        )
+        id: int = field(init=False)
+        name: Optional[str] = None
+        fullname: Optional[str] = None
+        nickname: Optional[str] = None
+        addresses: List[Address] = field(default_factory=list)
+
+        __mapper_args__ = {  # type: ignore
+            "properties" : {
+                "addresses": relationship("Address")
+            }
+        }
+
+We can't apply our ``Mapped[]`` types to the attributes ``id``, ``name``,
+etc. because they will be rejected by the ``@dataclass`` decorator.   Additionally,
+Mypy has another plugin for dataclasses explicitly which can also get in the
+way of what we're doing.
+
+The above class will actually pass Mypy's type checking without issue; the
+only thing we are missing is the ability for attributes on ``User`` to be
+used in SQL expressions, such as::
+
+    stmt = select(User.name).where(User.id.in_([1, 2, 3]))
+
+To provide a workaround for this, the Mypy plugin has an additional feature
+whereby we can specify an extra attribute ``_mypy_mapped_attrs``, that is
+a list that encloses the class-level objects or their string names.
+This attribute can be conditional within the ``TYPE_CHECKING`` variable::
+
+    @mapper_registry.mapped
+    @dataclass
+    class User:
+        __table__ = Table(
+            "user",
+            mapper_registry.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column("fullname", String(50)),
+            Column("nickname", String(12)),
+        )
+        id: int = field(init=False)
+        name: Optional[str] = None
+        fullname: Optional[str]
+        nickname: Optional[str]
+        addresses: List[Address] = field(default_factory=list)
+
+        if TYPE_CHECKING:
+            _mypy_mapped_attrs = [id, name, "fullname", "nickname", addresses]
+
+        __mapper_args__ = {  # type: ignore
+            "properties" : {
+                "addresses": relationship("Address")
+            }
+        }
+
+With the above recipe, the attributes listed in ``_mypy_mapped_attrs``
+will be applied with the :class:`_orm.Mapped` typing information so that the
+``User`` class will behave as a SQLAlchemy mapped class when used in a
+class-bound context.
\ No newline at end of file
index 8f26f7c3c0303bb11765719fa1e1ce1c3e7a65b1..8520fd07c14cdc5354886464e42222fa46023012 100644 (file)
@@ -60,6 +60,8 @@ sections, are listed here.
 
 .. autodata:: MANYTOMANY
 
+.. autoclass:: Mapped
+
 .. autoclass:: MapperProperty
     :members:
 
index 622159bd73e96358ef479bc4ee258e67d0052296..63288c74bca1b88033a81d03b90c7b40da56b720 100644 (file)
@@ -107,10 +107,16 @@ Documentation for Declarative mapping continues at :ref:`declarative_config_topl
 Creating an Explicit Base Non-Dynamically (for use with mypy, similar)
 ----------------------------------------------------------------------
 
-Tools like mypy are not necessarily compatible with the dynamically
-generated ``Base`` delivered by SQLAlchemy functions like :func:`_orm.declarative_base`.
-To build a declarative base in a non-dynamic fashion, the
-:class:`_orm.DeclarativeMeta` class may be used directly as follows::
+SQLAlchemy includes a :ref:`Mypy plugin <mypy_toplevel>` that automatically
+accommodates for the dynamically generated ``Base`` class
+delivered by SQLAlchemy functions like :func:`_orm.declarative_base`.
+This plugin works along with a new set of typing stubs published at
+`sqlalchemy2-stubs <https://pypi.org/project/sqlalhcemy-2-stubs>`_.
+
+When this plugin is not in use, or when using other :pep:`484` tools which
+may not know how to interpret this class, the declarative base class may
+be produced in a fully explicit fashion using the
+:class:`_orm.DeclarativeMeta` directly as follows::
 
     from sqlalchemy.orm import registry
     from sqlalchemy.orm.decl_api import DeclarativeMeta
@@ -119,6 +125,9 @@ To build a declarative base in a non-dynamic fashion, the
 
     class Base(metaclass=DeclarativeMeta):
         __abstract__ = True
+
+        # these are supplied by the sqlalchemy2-stubs, so may be omitted
+        # when they are installed
         registry = mapper_registry
         metadata = mapper_registry.metadata
 
@@ -126,6 +135,11 @@ The above ``Base`` is equivalent to one created using the
 :meth:`_orm.registry.generate_base` method and will be fully understood by
 type analysis tools without the use of plugins.
 
+.. seealso::
+
+    :ref:`mypy_toplevel` - background on the Mypy plugin which applies the
+    above structure automatically when running Mypy.
+
 
 .. _orm_declarative_decorator:
 
@@ -295,7 +309,7 @@ An example of a mapping using ``@dataclass`` using
         nickname: Optional[str] = None
         addresses: List[Address] = field(default_factory=list)
 
-        __mapper_args__ = {
+        __mapper_args__ = {   # type: ignore
             "properties" : {
                 "addresses": relationship("Address")
             }
@@ -428,15 +442,6 @@ A mapping using ``@attr.s``, in conjunction with imperative table::
 
     # other classes...
 
-.. sidebar:: Using MyPy with SQLAlchemy models
-
-    If you are using PEP 484 static type checkers for Python, a `MyPy
-    <http://mypy-lang.org/>`_ plugin is included with `type stubs for
-    SQLAlchemy <https://github.com/dropbox/sqlalchemy-stubs>`_.  The plugin is
-    tailored towards SQLAlchemy declarative models.   SQLAlchemy hopes to include
-    more comprehensive PEP 484 support in future releases.
-
-
 ``@dataclass`` and attrs_ mappings may also be used with classical mappings, i.e.
 with the :meth:`_orm.registry.map_imperatively` function.  See the section
 :ref:`orm_imperative_dataclasses` for a similar example.
index 4eb8a1d8d6698ac8f5a7b2f297f4e134f998f944..010abcc24f5ecc9efeed3a7c3dc8b7e044963ff4 100644 (file)
@@ -1138,7 +1138,7 @@ class CreateEnginePlugin(object):
     """  # noqa: E501
 
     def __init__(self, url, kwargs):
-        # type: (URL, dict[str: Any])
+        # type: (URL, dict[str, Any]) -> None
         """Construct a new :class:`.CreateEnginePlugin`.
 
         The plugin object is instantiated individually for each call
diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
new file mode 100644 (file)
index 0000000..f5215ca
--- /dev/null
@@ -0,0 +1,989 @@
+# ext/mypy/decl_class.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import Union
+
+from mypy import nodes
+from mypy import types
+from mypy.messages import format_type
+from mypy.nodes import ARG_NAMED_OPT
+from mypy.nodes import Argument
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import Decorator
+from mypy.nodes import JsonDict
+from mypy.nodes import ListExpr
+from mypy.nodes import MDEF
+from mypy.nodes import NameExpr
+from mypy.nodes import PlaceholderNode
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import add_method_to_class
+from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.subtypes import is_subtype
+from mypy.types import AnyType
+from mypy.types import Instance
+from mypy.types import NoneTyp
+from mypy.types import NoneType
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import names
+from . import util
+
+
+class DeclClassApplied:
+    def __init__(
+        self,
+        is_mapped: bool,
+        has_table: bool,
+        mapped_attr_names: Sequence[Tuple[str, Type]],
+        mapped_mro: Sequence[Type],
+    ):
+        self.is_mapped = is_mapped
+        self.has_table = has_table
+        self.mapped_attr_names = mapped_attr_names
+        self.mapped_mro = mapped_mro
+
+    def serialize(self) -> JsonDict:
+        return {
+            "is_mapped": self.is_mapped,
+            "has_table": self.has_table,
+            "mapped_attr_names": [
+                (name, type_.serialize())
+                for name, type_ in self.mapped_attr_names
+            ],
+            "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
+        }
+
+    @classmethod
+    def deserialize(
+        cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
+    ) -> "DeclClassApplied":
+
+        return DeclClassApplied(
+            is_mapped=data["is_mapped"],
+            has_table=data["has_table"],
+            mapped_attr_names=[
+                (name, deserialize_and_fixup_type(type_, api))
+                for name, type_ in data["mapped_attr_names"]
+            ],
+            mapped_mro=[
+                deserialize_and_fixup_type(type_, api)
+                for type_ in data["mapped_mro"]
+            ],
+        )
+
+
+def _scan_declarative_assignments_and_apply_types(
+    cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False
+) -> Optional[DeclClassApplied]:
+
+    if cls.fullname.startswith("builtins"):
+        return None
+    elif "_sa_decl_class_applied" in cls.info.metadata:
+        cls_metadata = DeclClassApplied.deserialize(
+            cls.info.metadata["_sa_decl_class_applied"], api
+        )
+
+        # ensure that a class that's mapped is always picked up by
+        # its mapped() decorator or declarative metaclass before
+        # it would be detected as an unmapped mixin class
+        if not is_mixin_scan:
+            assert cls_metadata.is_mapped
+
+            # mypy can call us more than once.  it then will have reset the
+            # left hand side of everything, but not the right that we removed,
+            # removing our ability to re-scan.   but we have the types
+            # here, so lets re-apply them.
+
+            _re_apply_declarative_assignments(cls, api, cls_metadata)
+
+        return cls_metadata
+
+    cls_metadata = DeclClassApplied(not is_mixin_scan, False, [], [])
+
+    for stmt in util._flatten_typechecking(cls.defs.body):
+        if isinstance(stmt, AssignmentStmt):
+            _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
+        elif isinstance(stmt, Decorator):
+            _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
+    _scan_for_mapped_bases(cls, api, cls_metadata)
+    _add_additional_orm_attributes(cls, api, cls_metadata)
+
+    cls.info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()
+
+    return cls_metadata
+
+
+def _scan_declarative_decorator_stmt(
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    stmt: Decorator,
+    cls_metadata: DeclClassApplied,
+):
+    """Extract mapping information from a @declared_attr in a declarative
+    class.
+
+    E.g.::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            @declared_attr
+            def updated_at(cls) -> Column[DateTime]:
+                return Column(DateTime)
+
+    Will resolve in mypy as::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            updated_at: Mapped[Optional[datetime.datetime]]
+
+    """
+    for dec in stmt.decorators:
+        if names._type_id_for_named_node(dec) is names.DECLARED_ATTR:
+            break
+    else:
+        return
+
+    dec_index = cls.defs.body.index(stmt)
+
+    left_hand_explicit_type = None
+
+    if stmt.func.type is not None:
+        func_type = stmt.func.type.ret_type
+        if isinstance(func_type, UnboundType):
+            type_id = names._type_id_for_unbound_type(func_type, cls, api)
+        else:
+            # this does not seem to occur unless the type argument is
+            # incorrect
+            return
+
+        if (
+            type_id
+            in {
+                names.MAPPED,
+                names.RELATIONSHIP,
+                names.COMPOSITE_PROPERTY,
+                names.MAPPER_PROPERTY,
+                names.SYNONYM_PROPERTY,
+                names.COLUMN_PROPERTY,
+            }
+            and func_type.args
+        ):
+            left_hand_explicit_type = func_type.args[0]
+        elif type_id is names.COLUMN and func_type.args:
+            typeengine_arg = func_type.args[0]
+            if isinstance(typeengine_arg, UnboundType):
+                sym = api.lookup(typeengine_arg.name, typeengine_arg)
+                if sym is not None and names._mro_has_id(
+                    sym.node.mro, names.TYPEENGINE
+                ):
+
+                    left_hand_explicit_type = UnionType(
+                        [
+                            _extract_python_type_from_typeengine(sym.node),
+                            NoneType(),
+                        ]
+                    )
+                else:
+                    util.fail(
+                        api,
+                        "Column type should be a TypeEngine "
+                        "subclass not '{}'".format(sym.node.fullname),
+                        func_type,
+                    )
+
+    if left_hand_explicit_type is None:
+        # no type on the decorated function.  our option here is to
+        # dig into the function body and get the return type, but they
+        # should just have an annotation.
+        msg = (
+            "Can't infer type from @declared_attr on function '{}';  "
+            "please specify a return type from this function that is "
+            "one of: Mapped[<python type>], relationship[<target class>], "
+            "Column[<TypeEngine>], MapperProperty[<python type>]"
+        )
+        util.fail(api, msg.format(stmt.var.name), stmt)
+
+        left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
+
+    left_node = NameExpr(stmt.var.name)
+    left_node.node = stmt.var
+
+    # totally feeling around in the dark here as I don't totally understand
+    # the significance of UnboundType.  It seems to be something that is
+    # not going to do what's expected when it is applied as the type of
+    # an AssignmentStatement.  So do a feeling-around-in-the-dark version
+    # of converting it to the regular Instance/TypeInfo/UnionType structures
+    # we see everywhere else.
+    if isinstance(left_hand_explicit_type, UnboundType):
+        left_hand_explicit_type = util._unbound_to_instance(
+            api, left_hand_explicit_type
+        )
+
+    left_node.node.type = Instance(descriptor.node, [left_hand_explicit_type])
+
+    # this will ignore the rvalue entirely
+    # rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+    # rewrite the node as:
+    # <attr> : Mapped[<typ>] =
+    # _sa_Mapped._empty_constructor(lambda: <function body>)
+    # the function body is maintained so it gets type checked internally
+    api.add_symbol_table_node("_sa_Mapped", descriptor)
+    column_descriptor = nodes.NameExpr("_sa_Mapped")
+    column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
+
+    arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body)
+    rvalue = CallExpr(
+        mm,
+        [arg],
+        [nodes.ARG_POS],
+        ["arg1"],
+    )
+
+    new_stmt = AssignmentStmt([left_node], rvalue)
+    new_stmt.type = left_node.node.type
+
+    cls_metadata.mapped_attr_names.append(
+        (left_node.name, left_hand_explicit_type)
+    )
+    cls.defs.body[dec_index] = new_stmt
+
+
+def _scan_declarative_assignment_stmt(
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    cls_metadata: DeclClassApplied,
+):
+    """Extract mapping information from an assignment statement in a
+    declarative class.
+
+    """
+    lvalue = stmt.lvalues[0]
+    if not isinstance(lvalue, NameExpr):
+        return
+
+    sym = cls.info.names.get(lvalue.name)
+
+    # this establishes that semantic analysis has taken place, which
+    # means the nodes are populated and we are called from an appropriate
+    # hook.
+    assert sym is not None
+    node = sym.node
+
+    if isinstance(node, PlaceholderNode):
+        return
+
+    assert node is lvalue.node
+    assert isinstance(node, Var)
+
+    if node.name == "__abstract__":
+        if stmt.rvalue.fullname == "builtins.True":
+            cls_metadata.is_mapped = False
+        return
+    elif node.name == "__tablename__":
+        cls_metadata.has_table = True
+    elif node.name.startswith("__"):
+        return
+    elif node.name == "_mypy_mapped_attrs":
+        if not isinstance(stmt.rvalue, ListExpr):
+            util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
+        else:
+            for item in stmt.rvalue.items:
+                if isinstance(item, (NameExpr, StrExpr)):
+                    _apply_mypy_mapped_attr(cls, api, item, cls_metadata)
+
+    left_hand_mapped_type: Type = None
+
+    if node.is_inferred or node.type is None:
+        if isinstance(stmt.type, UnboundType):
+            # look for an explicit Mapped[] type annotation on the left
+            # side with nothing on the right
+
+            # print(stmt.type)
+            # Mapped?[Optional?[A?]]
+
+            left_hand_explicit_type = stmt.type
+
+            if stmt.type.name == "Mapped":
+                mapped_sym = api.lookup("Mapped", cls)
+                if (
+                    mapped_sym is not None
+                    and names._type_id_for_named_node(mapped_sym.node)
+                    is names.MAPPED
+                ):
+                    left_hand_explicit_type = stmt.type.args[0]
+                    left_hand_mapped_type = stmt.type
+
+            # TODO: do we need to convert from unbound for this case?
+            # left_hand_explicit_type = util._unbound_to_instance(
+            #     api, left_hand_explicit_type
+            # )
+
+        else:
+            left_hand_explicit_type = None
+    else:
+        if (
+            isinstance(node.type, Instance)
+            and names._type_id_for_named_node(node.type.type) is names.MAPPED
+        ):
+            # print(node.type)
+            # sqlalchemy.orm.attributes.Mapped[<python type>]
+            left_hand_explicit_type = node.type.args[0]
+            left_hand_mapped_type = node.type
+        else:
+            # print(node.type)
+            # <python type>
+            left_hand_explicit_type = node.type
+            left_hand_mapped_type = None
+
+    if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
+        # annotation without assignment and Mapped is present
+        # as type annotation
+        # equivalent to using _infer_type_from_left_hand_type_only.
+
+        python_type_for_type = left_hand_explicit_type
+    elif isinstance(stmt.rvalue, CallExpr) and isinstance(
+        stmt.rvalue.callee, RefExpr
+    ):
+
+        type_id = names._type_id_for_callee(stmt.rvalue.callee)
+
+        if type_id is None:
+            return
+        elif type_id is names.COLUMN:
+            python_type_for_type = _infer_type_from_decl_column(
+                api, stmt, node, left_hand_explicit_type, stmt.rvalue
+            )
+        elif type_id is names.RELATIONSHIP:
+            python_type_for_type = _infer_type_from_relationship(
+                api, stmt, node, left_hand_explicit_type
+            )
+        elif type_id is names.COLUMN_PROPERTY:
+            python_type_for_type = _infer_type_from_decl_column_property(
+                api, stmt, node, left_hand_explicit_type
+            )
+        elif type_id is names.SYNONYM_PROPERTY:
+            python_type_for_type = _infer_type_from_left_hand_type_only(
+                api, node, left_hand_explicit_type
+            )
+        elif type_id is names.COMPOSITE_PROPERTY:
+            python_type_for_type = _infer_type_from_decl_composite_property(
+                api, stmt, node, left_hand_explicit_type
+            )
+        else:
+            return
+
+    else:
+        return
+
+    cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+
+    assert python_type_for_type is not None
+
+    _apply_type_to_mapped_statement(
+        api,
+        stmt,
+        lvalue,
+        left_hand_explicit_type,
+        python_type_for_type,
+    )
+
+
+def _apply_mypy_mapped_attr(
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    item: Union[NameExpr, StrExpr],
+    cls_metadata: DeclClassApplied,
+):
+    if isinstance(item, NameExpr):
+        name = item.name
+    elif isinstance(item, StrExpr):
+        name = item.value
+    else:
+        return
+
+    for stmt in cls.defs.body:
+        if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
+            break
+    else:
+        util.fail(api, "Can't find mapped attribute {}".format(name), cls)
+        return
+
+    if stmt.type is None:
+        util.fail(
+            api,
+            "Statement linked from _mypy_mapped_attrs has no "
+            "typing information",
+            stmt,
+        )
+        return
+
+    left_hand_explicit_type = stmt.type
+
+    cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+
+    _apply_type_to_mapped_statement(
+        api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
+    )
+
+
+def _infer_type_from_relationship(
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    node: Var,
+    left_hand_explicit_type: Optional[types.Type],
+) -> Union[Instance, UnionType, None]:
+    """Infer the type of mapping from a relationship.
+
+    E.g.::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            addresses = relationship(Address, uselist=True)
+
+            order: Mapped["Order"] = relationship("Order")
+
+    Will resolve in mypy as::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            addresses: Mapped[List[Address]]
+
+            order: Mapped["Order"]
+
+    """
+
+    assert isinstance(stmt.rvalue, CallExpr)
+    target_cls_arg = stmt.rvalue.args[0]
+    python_type_for_type = None
+
+    if isinstance(target_cls_arg, NameExpr) and isinstance(
+        target_cls_arg.node, TypeInfo
+    ):
+        # type
+        related_object_type = target_cls_arg.node
+        python_type_for_type = Instance(related_object_type, [])
+
+    # other cases not covered - an error message directs the user
+    # to set an explicit type annotation
+    #
+    # node.type == str, it's a string
+    # if isinstance(target_cls_arg, NameExpr) and isinstance(
+    #     target_cls_arg.node, Var
+    # )
+    # points to a type
+    # isinstance(target_cls_arg, NameExpr) and isinstance(
+    #     target_cls_arg.node, TypeAlias
+    # )
+    # string expression
+    # isinstance(target_cls_arg, StrExpr)
+
+    uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
+    collection_cls_arg = util._get_callexpr_kwarg(
+        stmt.rvalue, "collection_class"
+    )
+
+    # this can be used to determine Optional for a many-to-one
+    # in the same way nullable=False could be used, if we start supporting
+    # that.
+    # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+
+    if (
+        uselist_arg is not None
+        and uselist_arg.fullname == "builtins.True"
+        and collection_cls_arg is None
+    ):
+        if python_type_for_type is not None:
+            python_type_for_type = Instance(
+                api.lookup_fully_qualified("builtins.list").node,
+                [python_type_for_type],
+            )
+    elif (
+        uselist_arg is None or uselist_arg.fullname == "builtins.True"
+    ) and collection_cls_arg is not None:
+        if isinstance(collection_cls_arg.node, TypeInfo):
+            if python_type_for_type is not None:
+                python_type_for_type = Instance(
+                    collection_cls_arg.node, [python_type_for_type]
+                )
+        else:
+            util.fail(
+                api,
+                "Expected Python collection type for "
+                "collection_class parameter",
+                stmt.rvalue,
+            )
+            python_type_for_type = None
+    elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
+        if collection_cls_arg is not None:
+            util.fail(
+                api,
+                "Sending uselist=False and collection_class at the same time "
+                "does not make sense",
+                stmt.rvalue,
+            )
+        if python_type_for_type is not None:
+            python_type_for_type = UnionType(
+                [python_type_for_type, NoneType()]
+            )
+
+    else:
+        if left_hand_explicit_type is None:
+            msg = (
+                "Can't infer scalar or collection for ORM mapped expression "
+                "assigned to attribute '{}' if both 'uselist' and "
+                "'collection_class' arguments are absent from the "
+                "relationship(); please specify a "
+                "type annotation on the left hand side."
+            )
+            util.fail(api, msg.format(node.name), node)
+
+    if python_type_for_type is None:
+        return _infer_type_from_left_hand_type_only(
+            api, node, left_hand_explicit_type
+        )
+    elif left_hand_explicit_type is not None:
+        return _infer_type_from_left_and_inferred_right(
+            api, node, left_hand_explicit_type, python_type_for_type
+        )
+    else:
+        return python_type_for_type
+
+
+def _infer_type_from_decl_composite_property(
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    node: Var,
+    left_hand_explicit_type: Optional[types.Type],
+) -> Union[Instance, UnionType, None]:
+    """Infer the type of mapping from a CompositeProperty."""
+
+    assert isinstance(stmt.rvalue, CallExpr)
+    target_cls_arg = stmt.rvalue.args[0]
+    python_type_for_type = None
+
+    if isinstance(target_cls_arg, NameExpr) and isinstance(
+        target_cls_arg.node, TypeInfo
+    ):
+        related_object_type = target_cls_arg.node
+        python_type_for_type = Instance(related_object_type, [])
+    else:
+        python_type_for_type = None
+
+    if python_type_for_type is None:
+        return _infer_type_from_left_hand_type_only(
+            api, node, left_hand_explicit_type
+        )
+    elif left_hand_explicit_type is not None:
+        return _infer_type_from_left_and_inferred_right(
+            api, node, left_hand_explicit_type, python_type_for_type
+        )
+    else:
+        return python_type_for_type
+
+
+def _infer_type_from_decl_column_property(
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    node: Var,
+    left_hand_explicit_type: Optional[types.Type],
+) -> Union[Instance, UnionType, None]:
+    """Infer the type of mapping from a ColumnProperty.
+
+    This includes mappings against ``column_property()`` as well as the
+    ``deferred()`` function.
+
+    """
+    assert isinstance(stmt.rvalue, CallExpr)
+    first_prop_arg = stmt.rvalue.args[0]
+
+    if isinstance(first_prop_arg, CallExpr):
+        type_id = names._type_id_for_callee(first_prop_arg.callee)
+    else:
+        type_id = None
+
+    print(stmt.lvalues[0].name)
+
+    # look for column_property() / deferred() etc with Column as first
+    # argument
+    if type_id is names.COLUMN:
+        return _infer_type_from_decl_column(
+            api, stmt, node, left_hand_explicit_type, first_prop_arg
+        )
+    else:
+        return _infer_type_from_left_hand_type_only(
+            api, node, left_hand_explicit_type
+        )
+
+
+def _infer_type_from_decl_column(
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    node: Var,
+    left_hand_explicit_type: Optional[types.Type],
+    right_hand_expression: CallExpr,
+) -> Union[Instance, UnionType, None]:
+    """Infer the type of mapping from a Column.
+
+    E.g.::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            a = Column(Integer)
+
+            b = Column("b", String)
+
+            c: Mapped[int] = Column(Integer)
+
+            d: bool = Column(Boolean)
+
+    Will resolve in MyPy as::
+
+        @reg.mapped
+        class MyClass:
+            # ...
+
+            a : Mapped[int]
+
+            b : Mapped[str]
+
+            c: Mapped[int]
+
+            d: Mapped[bool]
+
+    """
+    assert isinstance(node, Var)
+
+    callee = None
+
+    for column_arg in right_hand_expression.args[0:2]:
+        if isinstance(column_arg, nodes.CallExpr):
+            # x = Column(String(50))
+            callee = column_arg.callee
+            break
+        elif isinstance(column_arg, nodes.NameExpr):
+            if isinstance(column_arg.node, TypeInfo):
+                # x = Column(String)
+                callee = column_arg
+                break
+            else:
+                # x = Column(some_name, String), go to next argument
+                continue
+        elif isinstance(column_arg, (StrExpr,)):
+            # x = Column("name", String), go to next argument
+            continue
+        else:
+            assert False
+
+    if callee is None:
+        return None
+
+    if names._mro_has_id(callee.node.mro, names.TYPEENGINE):
+        python_type_for_type = _extract_python_type_from_typeengine(
+            callee.node
+        )
+
+        if left_hand_explicit_type is not None:
+
+            return _infer_type_from_left_and_inferred_right(
+                api, node, left_hand_explicit_type, python_type_for_type
+            )
+
+        else:
+            python_type_for_type = UnionType(
+                [python_type_for_type, NoneType()]
+            )
+        return python_type_for_type
+    else:
+        # it's not TypeEngine, it's typically implicitly typed
+        # like ForeignKey.  we can't infer from the right side.
+        return _infer_type_from_left_hand_type_only(
+            api, node, left_hand_explicit_type
+        )
+
+
+def _infer_type_from_left_and_inferred_right(
+    api: SemanticAnalyzerPluginInterface,
+    node: Var,
+    left_hand_explicit_type: Optional[types.Type],
+    python_type_for_type: Union[Instance, UnionType],
+) -> Optional[Union[Instance, UnionType]]:
+    """Validate type when a left hand annotation is present and we also
+    could infer the right hand side::
+
+        attrname: SomeType = Column(SomeDBType)
+
+    """
+    if not is_subtype(left_hand_explicit_type, python_type_for_type):
+        descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
+
+        effective_type = Instance(descriptor.node, [python_type_for_type])
+
+        msg = (
+            "Left hand assignment '{}: {}' not compatible "
+            "with ORM mapped expression of type {}"
+        )
+        util.fail(
+            api,
+            msg.format(
+                node.name,
+                format_type(left_hand_explicit_type),
+                format_type(effective_type),
+            ),
+            node,
+        )
+
+    return left_hand_explicit_type
+
+
+def _infer_type_from_left_hand_type_only(
+    api: SemanticAnalyzerPluginInterface,
+    node: Var,
+    left_hand_explicit_type: Optional[types.Type],
+) -> Optional[Union[Instance, UnionType]]:
+    """Determine the type based on explicit annotation only.
+
+    if no annotation were present, note that we need one there to know
+    the type.
+
+    """
+    if left_hand_explicit_type is None:
+        msg = (
+            "Can't infer type from ORM mapped expression "
+            "assigned to attribute '{}'; please specify a "
+            "Python type or "
+            "Mapped[<python type>] on the left hand side."
+        )
+        util.fail(api, msg.format(node.name), node)
+
+        descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
+
+        return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)])
+
+    else:
+        # use type from the left hand side
+        return left_hand_explicit_type
+
+
+def _re_apply_declarative_assignments(
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    cls_metadata: DeclClassApplied,
+):
+    """For multiple class passes, re-apply our left-hand side types as mypy
+    seems to reset them in place.
+
+    """
+    mapped_attr_lookup = {
+        name: typ for name, typ in cls_metadata.mapped_attr_names
+    }
+
+    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
+
+    for stmt in cls.defs.body:
+        # for a re-apply, all of our statements are AssignmentStmt;
+        # @declared_attr calls will have been converted and this
+        # currently seems to be preserved by mypy (but who knows if this
+        # will change).
+        if (
+            isinstance(stmt, AssignmentStmt)
+            and stmt.lvalues[0].name in mapped_attr_lookup
+        ):
+            typ = mapped_attr_lookup[stmt.lvalues[0].name]
+            left_node = stmt.lvalues[0].node
+
+            inst = Instance(descriptor.node, [typ])
+            left_node.type = inst
+
+
+def _apply_type_to_mapped_statement(
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    lvalue: NameExpr,
+    left_hand_explicit_type: Optional[Union[Instance, UnionType]],
+    python_type_for_type: Union[Instance, UnionType],
+) -> None:
+    """Apply the Mapped[<type>] annotation and right hand object to a
+    declarative assignment statement.
+
+    This converts a Python declarative class statement such as::
+
+        class User(Base):
+            # ...
+
+            attrname = Column(Integer)
+
+    To one that describes the final Python behavior to Mypy::
+
+        class User(Base):
+            # ...
+
+            attrname : Mapped[Optional[int]] = <meaningless temp node>
+
+    """
+    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
+
+    left_node = lvalue.node
+
+    inst = Instance(descriptor.node, [python_type_for_type])
+
+    if left_hand_explicit_type is not None:
+        left_node.type = Instance(descriptor.node, [left_hand_explicit_type])
+    else:
+        lvalue.is_inferred_def = False
+        left_node.type = inst
+
+    # so to have it skip the right side totally, we can do this:
+    # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+    # however, if we instead manufacture a new node that uses the old
+    # one, then we can still get type checking for the call itself,
+    # e.g. the Column, relationship() call, etc.
+
+    # rewrite the node as:
+    # <attr> : Mapped[<typ>] =
+    # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
+    # the original right-hand side is maintained so it gets type checked
+    # internally
+    api.add_symbol_table_node("_sa_Mapped", descriptor)
+    column_descriptor = nodes.NameExpr("_sa_Mapped")
+    column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
+    orig_call_expr = stmt.rvalue
+    stmt.rvalue = CallExpr(
+        mm,
+        [orig_call_expr],
+        [nodes.ARG_POS],
+        ["arg1"],
+    )
+
+
+def _scan_for_mapped_bases(
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    cls_metadata: DeclClassApplied,
+) -> None:
+    """Given a class, iterate through its superclass hierarchy to find
+    all other classes that are considered as ORM-significant.
+
+    Locates non-mapped mixins and scans them for mapped attributes to be
+    applied to subclasses.
+
+    """
+
+    baseclasses = list(cls.info.bases)
+    while baseclasses:
+        base: Instance = baseclasses.pop(0)
+
+        # scan each base for mapped attributes.  if they are not already
+        # scanned, that means they are unmapped mixins
+        base_decl_class_applied = (
+            _scan_declarative_assignments_and_apply_types(
+                base.type.defn, api, is_mixin_scan=True
+            )
+        )
+        if base_decl_class_applied is not None:
+            cls_metadata.mapped_mro.append(base)
+        baseclasses.extend(base.type.bases)
+
+
+def _add_additional_orm_attributes(
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    cls_metadata: DeclClassApplied,
+) -> None:
+    """Apply __init__, __table__ and other attributes to the mapped class."""
+    if "__init__" not in cls.info.names and cls_metadata.is_mapped:
+        mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
+
+        for mapped_base in cls_metadata.mapped_mro:
+            base_cls_metadata = DeclClassApplied.deserialize(
+                mapped_base.type.metadata["_sa_decl_class_applied"], api
+            )
+            for n, t in base_cls_metadata.mapped_attr_names:
+                mapped_attr_names.setdefault(n, t)
+
+        arguments = []
+        for name, typ in mapped_attr_names.items():
+            if typ is None:
+                typ = AnyType(TypeOfAny.special_form)
+            arguments.append(
+                Argument(
+                    variable=Var(name, typ),
+                    type_annotation=typ,
+                    initializer=TempNode(typ),
+                    kind=ARG_NAMED_OPT,
+                )
+            )
+        add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
+
+    if "__table__" not in cls.info.names and cls_metadata.has_table:
+        _apply_placeholder_attr_to_class(
+            api, cls, "sqlalchemy.sql.schema.Table", "__table__"
+        )
+    if cls_metadata.is_mapped:
+        _apply_placeholder_attr_to_class(
+            api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
+        )
+
+
+def _apply_placeholder_attr_to_class(
+    api: SemanticAnalyzerPluginInterface,
+    cls: ClassDef,
+    qualified_name: str,
+    attrname: str,
+):
+    sym = api.lookup_fully_qualified_or_none(qualified_name)
+    if sym:
+        assert isinstance(sym.node, TypeInfo)
+        type_ = Instance(sym.node, [])
+    else:
+        type_ = AnyType(TypeOfAny.special_form)
+    var = Var(attrname)
+    var.info = cls.info
+    var.type = type_
+    cls.info.names[attrname] = SymbolTableNode(MDEF, var)
+
+
+def _extract_python_type_from_typeengine(node: TypeInfo) -> Instance:
+    for mr in node.mro:
+        if (
+            mr.bases
+            and mr.bases[-1].type.fullname
+            == "sqlalchemy.sql.type_api.TypeEngine"
+        ):
+            return mr.bases[-1].args[-1]
+    else:
+        assert False, "could not extract Python type from node: %s" % node
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
new file mode 100644 (file)
index 0000000..c9d48fc
--- /dev/null
@@ -0,0 +1,194 @@
+# ext/mypy/names.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+
+from mypy.nodes import ClassDef
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import RefExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import TypeAlias
+from mypy.nodes import TypeInfo
+from mypy.nodes import Union
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import UnboundType
+
+from ... import util
+
+COLUMN = util.symbol("COLUMN")
+RELATIONSHIP = util.symbol("RELATIONSHIP")
+REGISTRY = util.symbol("REGISTRY")
+COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
+TYPEENGINE = util.symbol("TYPEENGNE")
+MAPPED = util.symbol("MAPPED")
+DECLARATIVE_BASE = util.symbol("DECLARATIVE_BASE")
+DECLARATIVE_META = util.symbol("DECLARATIVE_META")
+MAPPED_DECORATOR = util.symbol("MAPPED_DECORATOR")
+COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
+SYNONYM_PROPERTY = util.symbol("SYNONYM_PROPERTY")
+COMPOSITE_PROPERTY = util.symbol("COMPOSITE_PROPERTY")
+DECLARED_ATTR = util.symbol("DECLARED_ATTR")
+MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY")
+
+
+_lookup = {
+    "Column": (
+        COLUMN,
+        {
+            "sqlalchemy.sql.schema.Column",
+            "sqlalchemy.sql.Column",
+        },
+    ),
+    "RelationshipProperty": (
+        RELATIONSHIP,
+        {
+            "sqlalchemy.orm.relationships.RelationshipProperty",
+            "sqlalchemy.orm.RelationshipProperty",
+        },
+    ),
+    "registry": (
+        REGISTRY,
+        {
+            "sqlalchemy.orm.decl_api.registry",
+            "sqlalchemy.orm.registry",
+        },
+    ),
+    "ColumnProperty": (
+        COLUMN_PROPERTY,
+        {
+            "sqlalchemy.orm.properties.ColumnProperty",
+            "sqlalchemy.orm.ColumnProperty",
+        },
+    ),
+    "SynonymProperty": (
+        SYNONYM_PROPERTY,
+        {
+            "sqlalchemy.orm.descriptor_props.SynonymProperty",
+            "sqlalchemy.orm.SynonymProperty",
+        },
+    ),
+    "CompositeProperty": (
+        COMPOSITE_PROPERTY,
+        {
+            "sqlalchemy.orm.descriptor_props.CompositeProperty",
+            "sqlalchemy.orm.CompositeProperty",
+        },
+    ),
+    "MapperProperty": (
+        MAPPER_PROPERTY,
+        {
+            "sqlalchemy.orm.interfaces.MapperProperty",
+            "sqlalchemy.orm.MapperProperty",
+        },
+    ),
+    "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
+    "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
+    "declarative_base": (
+        DECLARATIVE_BASE,
+        {
+            "sqlalchemy.ext.declarative.declarative_base",
+            "sqlalchemy.orm.declarative_base",
+            "sqlalchemy.orm.decl_api.declarative_base",
+        },
+    ),
+    "DeclarativeMeta": (
+        DECLARATIVE_META,
+        {
+            "sqlalchemy.ext.declarative.DeclarativeMeta",
+            "sqlalchemy.orm.DeclarativeMeta",
+            "sqlalchemy.orm.decl_api.DeclarativeMeta",
+        },
+    ),
+    "mapped": (
+        MAPPED_DECORATOR,
+        {
+            "sqlalchemy.orm.decl_api.registry.mapped",
+            "sqlalchemy.orm.registry.mapped",
+        },
+    ),
+    "declared_attr": (
+        DECLARED_ATTR,
+        {
+            "sqlalchemy.orm.decl_api.declared_attr",
+            "sqlalchemy.orm.declared_attr",
+        },
+    ),
+}
+
+
+def _mro_has_id(mro: List[TypeInfo], type_id: int):
+    for mr in mro:
+        check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+        if check_type_id == type_id:
+            break
+    else:
+        return False
+
+    return mr.fullname in fullnames
+
+
+def _type_id_for_unbound_type(
+    type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> int:
+    type_id = None
+
+    sym = api.lookup(type_.name, type_)
+    if sym is not None:
+        if isinstance(sym.node, TypeAlias):
+            type_id = _type_id_for_named_node(sym.node.target.type)
+        elif isinstance(sym.node, TypeInfo):
+            type_id = _type_id_for_named_node(sym.node)
+
+    return type_id
+
+
+def _type_id_for_callee(callee: Expression) -> int:
+    if isinstance(callee.node, FuncDef):
+        return _type_id_for_funcdef(callee.node)
+    elif isinstance(callee.node, TypeAlias):
+        type_id = _type_id_for_fullname(callee.node.target.type.fullname)
+    elif isinstance(callee.node, TypeInfo):
+        type_id = _type_id_for_named_node(callee)
+    else:
+        type_id = None
+    return type_id
+
+
+def _type_id_for_funcdef(node: FuncDef) -> int:
+    if hasattr(node.type.ret_type, "type"):
+        type_id = _type_id_for_fullname(node.type.ret_type.type.fullname)
+    else:
+        type_id = None
+    return type_id
+
+
+def _type_id_for_named_node(node: Union[RefExpr, SymbolNode]) -> int:
+    type_id, fullnames = _lookup.get(node.name, (None, None))
+
+    if type_id is None:
+        return None
+
+    elif node.fullname in fullnames:
+        return type_id
+    else:
+        return None
+
+
+def _type_id_for_fullname(fullname: str) -> int:
+    tokens = fullname.split(".")
+    immediate = tokens[-1]
+
+    type_id, fullnames = _lookup.get(immediate, (None, None))
+
+    if type_id is None:
+        return None
+
+    elif fullname in fullnames:
+        return type_id
+    else:
+        return None
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
new file mode 100644 (file)
index 0000000..9fcd09b
--- /dev/null
@@ -0,0 +1,215 @@
+# ext/mypy/plugin.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Mypy plugin for SQLAlchemy ORM.
+
+"""
+from typing import List
+from typing import Tuple
+from typing import Type
+
+from mypy import nodes
+from mypy.mro import calculate_mro
+from mypy.mro import MroError
+from mypy.nodes import Block
+from mypy.nodes import ClassDef
+from mypy.nodes import GDEF
+from mypy.nodes import MypyFile
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTable
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import AttributeContext
+from mypy.plugin import Callable
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import Optional
+from mypy.plugin import Plugin
+from mypy.types import Instance
+
+from . import decl_class
+from . import names
+from . import util
+
+
+class CustomPlugin(Plugin):
+    def get_dynamic_class_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[DynamicClassDefContext], None]]:
+        if names._type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+            return _dynamic_class_hook
+        return None
+
+    def get_base_class_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+
+        # kind of a strange relationship between get_metaclass_hook()
+        # and get_base_class_hook().  the former doesn't fire off for
+        # subclasses.   but then you can just check it here from the "base"
+        # and get the same effect.
+        sym = self.lookup_fully_qualified(fullname)
+        if (
+            sym
+            and isinstance(sym.node, TypeInfo)
+            and sym.node.metaclass_type
+            and names._type_id_for_named_node(sym.node.metaclass_type.type)
+            is names.DECLARATIVE_META
+        ):
+            return _base_cls_hook
+        return None
+
+    def get_class_decorator_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+
+        sym = self.lookup_fully_qualified(fullname)
+
+        if (
+            sym is not None
+            and names._type_id_for_named_node(sym.node)
+            is names.MAPPED_DECORATOR
+        ):
+            return _cls_decorator_hook
+        return None
+
+    def get_customize_class_mro_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+        return _fill_in_decorators
+
+    def get_attribute_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[AttributeContext], Type]]:
+        if fullname.startswith(
+            "sqlalchemy.orm.attributes.QueryableAttribute."
+        ):
+            return _queryable_getattr_hook
+        return None
+
+    def get_additional_deps(
+        self, file: MypyFile
+    ) -> List[Tuple[int, str, int]]:
+        return [
+            (10, "sqlalchemy.orm.attributes", -1),
+            (10, "sqlalchemy.orm.decl_api", -1),
+        ]
+
+
+def plugin(version: str):
+    return CustomPlugin
+
+
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+    # how do I....tell it it has no attribute of a certain name?
+    # can't find any Type that seems to match that
+    return ctx.default_attr_type
+
+
+def _fill_in_decorators(ctx: ClassDefContext) -> None:
+    for decorator in ctx.cls.decorators:
+        # set the ".fullname" attribute of a class decorator
+        # that is a MemberExpr.   This causes the logic in
+        # semanal.py->apply_class_plugin_hooks to invoke the
+        # get_class_decorator_hook for our "registry.map_class()" method.
+        # this seems like a bug in mypy that these decorators are otherwise
+        # skipped.
+        if (
+            isinstance(decorator, nodes.MemberExpr)
+            and decorator.name == "mapped"
+        ):
+
+            sym = ctx.api.lookup(
+                decorator.expr.name, decorator, suppress_errors=True
+            )
+            if sym:
+                if sym.node.type and hasattr(sym.node.type, "type"):
+                    decorator.fullname = (
+                        f"{sym.node.type.type.fullname}.{decorator.name}"
+                    )
+                else:
+                    # if the registry is in the same file as where the
+                    # decorator is used, it might not have semantic
+                    # symbols applied and we can't get a fully qualified
+                    # name or an inferred type, so we are actually going to
+                    # flag an error in this case that they need to annotate
+                    # it.  The "registry" is declared just
+                    # once (or few times), so they have to just not use
+                    # type inference for its assignment in this one case.
+                    util.fail(
+                        ctx.api,
+                        "Class decorator called mapped(), but we can't "
+                        "tell if it's from an ORM registry.  Please "
+                        "annotate the registry assignment, e.g. "
+                        "my_registry: registry = registry()",
+                        sym.node,
+                    )
+
+
+def _cls_metadata_hook(ctx: ClassDefContext) -> None:
+    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _cls_decorator_hook(ctx: ClassDefContext) -> None:
+    assert isinstance(ctx.reason, nodes.MemberExpr)
+    expr = ctx.reason.expr
+    assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY
+
+    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+    """Generate a declarative Base class when the declarative_base() function
+    is encountered."""
+
+    cls = ClassDef(ctx.name, Block([]))
+    cls.fullname = ctx.api.qualified_name(ctx.name)
+
+    declarative_meta_sym: SymbolTableNode = ctx.api.modules[
+        "sqlalchemy.orm.decl_api"
+    ].names["DeclarativeMeta"]
+    declarative_meta_typeinfo: TypeInfo = declarative_meta_sym.node
+
+    declarative_meta_name: NameExpr = NameExpr("DeclarativeMeta")
+    declarative_meta_name.kind = GDEF
+    declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta"
+    declarative_meta_name.node = declarative_meta_typeinfo
+
+    cls.metaclass = declarative_meta_name
+
+    declarative_meta_instance = Instance(declarative_meta_typeinfo, [])
+
+    info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+    info.declared_metaclass = info.metaclass_type = declarative_meta_instance
+    cls.info = info
+
+    cls_arg = util._get_callexpr_kwarg(ctx.call, "cls")
+    if cls_arg is not None:
+        decl_class._scan_declarative_assignments_and_apply_types(
+            cls_arg.node.defn, ctx.api, is_mixin_scan=True
+        )
+        info.bases = [Instance(cls_arg.node, [])]
+    else:
+        obj = ctx.api.builtin_type("builtins.object")
+
+        info.bases = [obj]
+
+    try:
+        calculate_mro(info)
+    except MroError:
+        util.fail(
+            ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+        )
+        info.bases = [obj]
+        info.fallback_to_any = True
+
+    ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
new file mode 100644 (file)
index 0000000..e7178a8
--- /dev/null
@@ -0,0 +1,80 @@
+from typing import Optional
+
+from mypy.nodes import CallExpr
+from mypy.nodes import Context
+from mypy.nodes import IfStmt
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTableNode
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import Type
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
+    msg = "[SQLAlchemy Mypy plugin] %s" % msg
+    return api.fail(msg, ctx)
+
+
+def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]:
+    try:
+        arg_idx = callexpr.arg_names.index(name)
+    except ValueError:
+        return None
+
+    return callexpr.args[arg_idx]
+
+
+def _flatten_typechecking(stmts):
+    for stmt in stmts:
+        if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING":
+            for substmt in stmt.body[0].body:
+                yield substmt
+        else:
+            yield stmt
+
+
+def _unbound_to_instance(
+    api: SemanticAnalyzerPluginInterface, typ: UnboundType
+) -> Type:
+    """Take the UnboundType that we seem to get as the ret_type from a FuncDef
+    and convert it into an Instance/TypeInfo kind of structure that seems
+    to work as the left-hand type of an AssignmentStatement.
+
+    """
+
+    if not isinstance(typ, UnboundType):
+        return typ
+
+    # TODO: figure out a more robust way to check this.  The node is some
+    # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
+    # but I cant figure out how to get them to match up
+    if typ.name == "Optional":
+        # convert from "Optional?" to the more familiar
+        # UnionType[..., NoneType()]
+        return _unbound_to_instance(
+            api,
+            UnionType(
+                [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+                + [NoneType()]
+            ),
+        )
+
+    node = api.lookup(typ.name, typ)
+
+    if node is not None and isinstance(node, SymbolTableNode):
+        bound_type = node.node
+
+        return Instance(
+            bound_type,
+            [
+                _unbound_to_instance(api, arg)
+                if isinstance(arg, UnboundType)
+                else arg
+                for arg in typ.args
+            ],
+        )
+    else:
+        return typ
index daab8a2e3604760a913f7190c9597b0bf584c8eb..025d826e34f102c88518a452f4781d24222384de 100644 (file)
@@ -18,6 +18,7 @@ from . import mapper as mapperlib
 from . import strategy_options
 from .attributes import AttributeEvent
 from .attributes import InstrumentedAttribute
+from .attributes import Mapped
 from .attributes import QueryableAttribute
 from .context import QueryContext
 from .decl_api import as_declarative
index b96b3b61e34885125aa11d8f8ec8f80ff88e09a1..2e48695f51046878a979dbf1031ad08d68084a72 100644 (file)
@@ -334,7 +334,85 @@ def _queryable_attribute_unreduce(key, mapped_class, parententity, entity):
         return getattr(entity, key)
 
 
-class InstrumentedAttribute(QueryableAttribute):
+if util.py3k:
+    from typing import TypeVar, Generic
+
+    _T = TypeVar("_T")
+    _Generic_T = Generic[_T]
+else:
+    _Generic_T = type("_Generic_T", (), {})
+
+
+class Mapped(QueryableAttribute, _Generic_T):
+    """Represent an ORM mapped :term:`descriptor` attribute for typing purposes.
+
+    This class represents the complete descriptor interface for any class
+    attribute that will have been :term:`instrumented` by the ORM
+    :class:`_orm.Mapper` class. When used with typing stubs, it is the final
+    type that would be used by a type checker such as mypy to provide the full
+    behavioral contract for the attribute.
+
+    .. tip::
+
+        The :class:`_orm.Mapped` class represents attributes that are handled
+        directly by the :class:`_orm.Mapper` class. It does not include other
+        Python descriptor classes that are provided as extensions, including
+        :ref:`hybrids_toplevel` and the :ref:`associationproxy_toplevel`.
+        While these systems still make use of ORM-specific superclasses
+        and structures, they are not :term:`instrumented` by the
+        :class:`_orm.Mapper` and instead provide their own functionality
+        when they are accessed on a class.
+
+    When using the :ref:`SQLAlchemy Mypy plugin <mypy_toplevel>`, the
+    :class:`_orm.Mapped` construct is used in typing annotations to indicate to
+    the plugin those attributes that are expected to be mapped; the plugin also
+    applies :class:`_orm.Mapped` as an annotation automatically when it scans
+    through declarative mappings in :ref:`orm_declarative_table` style. For
+    more indirect mapping styles such as
+    :ref:`imperative table <orm_imperative_table_configuration>` it is
+    typically applied explicitly to class level attributes that expect
+    to be mapped based on a given :class:`_schema.Table` configuration.
+
+    :class:`_orm.Mapped` is defined in the
+    `sqlalchemy2-stubs <https://pypi.org/project/sqlalchemy2-stubs>`_ project
+    as a :pep:`484` generic class which may subscribe to any arbitrary Python
+    type, which represents the Python type handled by the attribute::
+
+        class MyMappedClass(Base):
+            __table_ = Table(
+                "some_table", Base.metadata,
+                Column("id", Integer, primary_key=True),
+                Column("data", String(50)),
+                Column("created_at", DateTime)
+            )
+
+            id : Mapped[int]
+            data: Mapped[str]
+            created_at: Mapped[datetime]
+
+    For complete background on how to use :class:`_orm.Mapped` with
+    pep-484 tools like Mypy, see the link below for background on SQLAlchemy's
+    Mypy plugin.
+
+    .. versionadded:: 1.4
+
+    .. seealso::
+
+        :ref:`mypy_toplevel` - complete background on Mypy integration
+
+    """
+
+    def __get__(self, instance, owner):
+        raise NotImplementedError()
+
+    def __set__(self, instance, value):
+        raise NotImplementedError()
+
+    def __delete__(self, instance):
+        raise NotImplementedError()
+
+
+class InstrumentedAttribute(Mapped):
     """Class bound instrumented attribute which adds basic
     :term:`descriptor` methods.
 
@@ -1369,6 +1447,7 @@ class CollectionAttributeImpl(AttributeImpl):
         value,
         initiator=None,
         passive=PASSIVE_OFF,
+        check_old=None,
         pop=False,
         _adapt=True,
     ):
index 994a76bdc855ed577020341cbde7f93fb8e6c6c8..7aae9ec37af577098da38b2087870596e8d058a3 100644 (file)
@@ -167,7 +167,7 @@ class MapperProperty(
         """
 
     def cascade_iterator(
-        self, type_, state, visited_instances=None, halt_on=None
+        self, type_, state, dict_, visited_states, halt_on=None
     ):
         """Iterate through instances related to the given instance for
         a particular 'cascade', starting with this MapperProperty.
index f16ba326cc8abde621f11ecbdd77964f7df46df1..de2b8f12c3d0a84a08a396921e92be5faf5b1f68 100644 (file)
@@ -1172,6 +1172,18 @@ class SuiteRequirements(Requirements):
             "Stability issues with coverage + py3k",
         )
 
+    @property
+    def sqlalchemy2_stubs(self):
+        def check(config):
+            try:
+                __import__("sqlalchemy-stubs.ext.mypy")
+            except ImportError:
+                return False
+            else:
+                return True
+
+        return exclusions.only_if(check)
+
     @property
     def python2(self):
         return exclusions.skip_if(
index 1e11063f3b048c37fbb5257d33eb3f548cba1841..fd196f4f56cf42cd968bd63ee50294e49f6fa2ba 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,6 +45,9 @@ install_requires =
 [options.extras_require]
 asyncio =
     greenlet!=0.4.17;python_version>="3"
+mypy =
+    mypy >= 0.800;python_version>="3"
+    sqlalchemy2-stubs
 mssql = pyodbc
 mssql_pymssql = pymssql
 mssql_pyodbc = pyodbc
@@ -108,6 +111,10 @@ per-file-ignores =
                 lib/sqlalchemy/types.py:F401
                 lib/sqlalchemy/sql/expression.py:F401
 
+[mypy]
+# min mypy version 0.800
+plugins = sqlalchemy.ext.mypy.plugin
+
 [sqla_testing]
 requirement_cls = test.requirements:DefaultRequirements
 profile_file = test/profiles.txt
diff --git a/test/ext/mypy/files/abstract_one.py b/test/ext/mypy/files/abstract_one.py
new file mode 100644 (file)
index 0000000..d11631d
--- /dev/null
@@ -0,0 +1,28 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+
+
+Base = declarative_base()
+
+
+class FooBase(Base):
+    __abstract__ = True
+
+    updated_at = Column(Integer)
+
+
+class Foo(FooBase):
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+
+
+Foo.updated_at.in_([1, 2, 3])
+
+f1 = Foo(name="name", updated_at=5)
+
+# test that we read the __abstract__ flag and don't apply a constructor
+# EXPECTED_MYPY: Unexpected keyword argument "updated_at" for "FooBase"
+FooBase(updated_at=5)
diff --git a/test/ext/mypy/files/cols_noninferred_plain_nonopt.py b/test/ext/mypy/files/cols_noninferred_plain_nonopt.py
new file mode 100644 (file)
index 0000000..a2825e0
--- /dev/null
@@ -0,0 +1,36 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+
+
+reg: registry = registry()
+
+
+@reg.mapped
+class Foo:
+    id: int = Column(Integer())
+    name: str = Column(String)
+    other_name: str = Column(String(50))
+
+    # has a string key in it
+    third_name = Column("foo", String(50))
+
+    some_name = "fourth_name"
+
+    fourth_name = Column(some_name, String(50))
+
+
+f1 = Foo()
+
+# This needs to work, e.g., value is "int" at the instance level
+val: int = f1.id  # noqa
+
+# also, the type are not optional, since we used an explicit
+# type without Optional
+p: str = f1.name
+
+Foo.id.property
+
+
+Foo(name="n", other_name="on", third_name="tn", fourth_name="fn")
diff --git a/test/ext/mypy/files/cols_notype_on_fk_col.py b/test/ext/mypy/files/cols_notype_on_fk_col.py
new file mode 100644 (file)
index 0000000..3195714
--- /dev/null
@@ -0,0 +1,44 @@
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+    name = Column(String)
+
+
+@reg.mapped
+class Address:
+    __tablename__ = "address"
+
+    id = Column(Integer, primary_key=True)
+    user_id: Mapped[int] = Column(ForeignKey("user.id"))
+    email_address = Column(String)
+
+
+ad1 = Address()
+
+p: Optional[int] = ad1.user_id
+
+# it's not optional because we called it Mapped[int]
+# and not Mapped[Optional[int]]
+p2: int = ad1.user_id
+
+
+# class-level descriptor access
+User.name.in_(["x", "y"])
+
+
+# class-level descriptor access
+Address.user_id.in_([1, 2])
diff --git a/test/ext/mypy/files/complete_orm_no_plugin.py b/test/ext/mypy/files/complete_orm_no_plugin.py
new file mode 100644 (file)
index 0000000..5329150
--- /dev/null
@@ -0,0 +1,96 @@
+# NOPLUGINS
+# this should pass typing with no plugins
+
+from typing import Any
+from typing import List
+from typing import Mapping
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import create_engine
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.orm.attributes import Mapped
+from sqlalchemy.orm.decl_api import DeclarativeMeta
+
+
+class Base(metaclass=DeclarativeMeta):
+    __abstract__ = True
+    registry = registry()
+    metadata = registry.metadata
+
+
+class A(Base):
+    __table__ = Table(
+        "a",
+        Base.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("data", String),
+    )
+
+    __mapper_args__: Mapping[str, Any] = {
+        "properties": {"bs": relationship("B")}
+    }
+
+    id: Mapped[int]
+    data: Mapped[str]
+    bs: "Mapped[List[B]]"
+
+    def __init__(
+        self,
+        id: Optional[int] = None,  # noqa: A002
+        data: Optional[str] = None,
+        bs: "Optional[List[B]]" = None,
+    ):
+        self.registry.constructor(self, id=id, data=data, bs=bs)
+
+
+class B(Base):
+    __table__ = Table(
+        "b",
+        Base.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("a_id", ForeignKey("a.id")),
+        Column("data", String),
+    )
+    id: Mapped[int]
+    a_id: Mapped[int]
+    data: Mapped[str]
+
+    def __init__(
+        self,
+        id: Optional[int] = None,  # noqa: A002
+        a_id: Optional[int] = None,
+        data: Optional[str] = None,
+    ):
+        self.registry.constructor(self, id=id, a_id=a_id, data=data)
+
+
+e = create_engine("sqlite://", echo=True)
+Base.metadata.create_all(e)
+
+s = Session(e)
+
+
+a1 = A(data="some data", bs=[B(data="some data")])
+
+x: List[B] = a1.bs
+
+s.add(a1)
+s.commit()
+
+# illustrate descriptor working at the class level, A.data.in_()
+stmt = (
+    select(A.data, B.data)
+    .join(B)
+    .where(A.data.in_(["some data", "some other data"]))
+)
+
+for row in s.execute(stmt):
+    print(row)
diff --git a/test/ext/mypy/files/composite_props.py b/test/ext/mypy/files/composite_props.py
new file mode 100644 (file)
index 0000000..f92b93c
--- /dev/null
@@ -0,0 +1,60 @@
+from typing import Any
+from typing import Tuple
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import composite
+
+Base = declarative_base()
+
+
+class Point:
+    def __init__(self, x: int, y: int):
+        self.x = x
+        self.y = y
+
+    def __composite_values__(self) -> Tuple[int, int]:
+        return self.x, self.y
+
+    def __repr__(self) -> str:
+        return "Point(x=%r, y=%r)" % (self.x, self.y)
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, Point)
+            and other.x == self.x
+            and other.y == self.y
+        )
+
+    def __ne__(self, other: Any) -> bool:
+        return not self.__eq__(other)
+
+
+class Vertex(Base):
+    __tablename__ = "vertices"
+
+    id = Column(Integer, primary_key=True)
+    x1 = Column(Integer)
+    y1 = Column(Integer)
+    x2 = Column(Integer)
+    y2 = Column(Integer)
+
+    # inferred from right hand side
+    start = composite(Point, x1, y1)
+
+    # taken from left hand side
+    end: Point = composite(Point, x2, y2)
+
+
+v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
+
+# I'm not even sure composites support this but it should work from a
+# typing perspective
+stmt = select(v1).where(Vertex.start.in_([Point(3, 4)]))
+
+p1: Point = v1.start
+p2: Point = v1.end
+
+y3: int = v1.end.y
diff --git a/test/ext/mypy/files/constr_cols_only.py b/test/ext/mypy/files/constr_cols_only.py
new file mode 100644 (file)
index 0000000..cd4da55
--- /dev/null
@@ -0,0 +1,29 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    x = Column(Integer)
+    y = Column(Integer)
+
+
+a1 = A(data="d", x=5, y=4)
+
+
+# EXPECTED_MYPY: Argument "data" to "A" has incompatible type "int"; expected "Optional[str]" # noqa
+a2 = A(data=5)
+
+# EXPECTED_MYPY: Unexpected keyword argument "nonexistent" for "A"
+a3 = A(nonexistent="hi")
+
+print(a1)
+print(a2)
+print(a3)
diff --git a/test/ext/mypy/files/dataclasses_workaround.py b/test/ext/mypy/files/dataclasses_workaround.py
new file mode 100644 (file)
index 0000000..f8ee387
--- /dev/null
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from dataclasses import field
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+
+mapper_registry: registry = registry()
+
+
+@mapper_registry.mapped
+@dataclass
+class User:
+    __table__ = Table(
+        "user",
+        mapper_registry.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String(50)),
+        Column("fullname", String(50)),
+        Column("nickname", String(12)),
+    )
+    id: int = field(init=False)
+    name: Optional[str] = None
+    fullname: Optional[str] = None
+    nickname: Optional[str] = None
+    addresses: List[Address] = field(default_factory=list)
+
+    if TYPE_CHECKING:
+        _mypy_mapped_attrs = [id, name, fullname, nickname, addresses]
+
+    __mapper_args__ = {  # type: ignore
+        "properties": {"addresses": relationship("Address")}
+    }
+
+
+@mapper_registry.mapped
+@dataclass
+class Address:
+    __table__ = Table(
+        "address",
+        mapper_registry.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("user_id", Integer, ForeignKey("user.id")),
+        Column("email_address", String(50)),
+    )
+
+    id: int = field(init=False)
+    user_id: int = field(init=False)
+    email_address: Optional[str] = None
+
+    if TYPE_CHECKING:
+        _mypy_mapped_attrs = [id, user_id, email_address]
+
+
+stmt = select(User.name).where(User.id.in_([1, 2, 3]))
+stmt = select(Address).where(Address.email_address.contains(["foo"]))
diff --git a/test/ext/mypy/files/decl_attrs_one.py b/test/ext/mypy/files/decl_attrs_one.py
new file mode 100644 (file)
index 0000000..1f2261c
--- /dev/null
@@ -0,0 +1,37 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+from sqlalchemy.sql.schema import ForeignKey
+from sqlalchemy.sql.schema import MetaData
+from sqlalchemy.sql.schema import Table
+
+
+reg: registry = registry()
+
+
+@reg.mapped
+class Foo:
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+
+
+@reg.mapped
+class Bar(Foo):
+    __tablename__ = "bar"
+    id: int = Column(ForeignKey("foo.id"), primary_key=True)
+
+
+@reg.mapped
+class Bat(Foo):
+    pass
+
+
+m1: MetaData = reg.metadata
+
+t1: Table = Foo.__table__
+
+t2: Table = Bar.__table__
+
+t3: Table = Bat.__table__
diff --git a/test/ext/mypy/files/decl_attrs_two.py b/test/ext/mypy/files/decl_attrs_two.py
new file mode 100644 (file)
index 0000000..a20af49
--- /dev/null
@@ -0,0 +1,39 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import registry
+from sqlalchemy.sql.schema import ForeignKey
+from sqlalchemy.sql.schema import MetaData
+from sqlalchemy.sql.schema import Table
+
+
+Base = declarative_base()
+
+
+class Foo(Base):
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+
+
+class Bar(Foo):
+    __tablename__ = "bar"
+    id: int = Column(ForeignKey("foo.id"), primary_key=True)
+
+
+class Bat(Foo):
+    pass
+
+
+m0: MetaData = Base.metadata
+r0: registry = Base.registry
+
+t1: Table = Foo.__table__
+m1: MetaData = Foo.metadata
+
+t2: Table = Bar.__table__
+m2: MetaData = Bar.metadata
+
+t3: Table = Bat.__table__
+m3: MetaData = Bat.metadata
diff --git a/test/ext/mypy/files/decl_base_subclass_one.py b/test/ext/mypy/files/decl_base_subclass_one.py
new file mode 100644 (file)
index 0000000..abe28a4
--- /dev/null
@@ -0,0 +1,30 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+
+
+class _Base:
+    updated_at = Column(Integer)
+
+
+Base = declarative_base(cls=_Base)
+
+
+class Foo(Base):
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+
+
+class Bar(Base):
+    __tablename__ = "bar"
+    id = Column(Integer(), primary_key=True)
+    num = Column(Integer)
+
+
+Foo.updated_at.in_([1, 2, 3])
+
+f1 = Foo(name="name", updated_at=5)
+
+b1 = Bar(num=5, updated_at=6)
diff --git a/test/ext/mypy/files/decl_base_subclass_two.py b/test/ext/mypy/files/decl_base_subclass_two.py
new file mode 100644 (file)
index 0000000..78b7a9b
--- /dev/null
@@ -0,0 +1,73 @@
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm.decl_api import declared_attr
+from sqlalchemy.sql.schema import ForeignKey
+from sqlalchemy.sql.sqltypes import Integer
+from sqlalchemy.sql.sqltypes import String
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String(50))
+
+    name3 = Column(String(50))
+
+    addresses: List["Address"] = relationship("Address")
+
+
+@reg.mapped
+class SubUser(User):
+    __tablename__ = "subuser"
+
+    id: int = Column(ForeignKey("user.id"), primary_key=True)
+
+    @declared_attr
+    def name(cls) -> Column[String]:
+        return Column(String(50))
+
+    @declared_attr
+    def name2(cls) -> Mapped[Optional[str]]:
+        return Column(String(50))
+
+    @declared_attr
+    def name3(cls) -> Mapped[str]:
+        return Column(String(50))
+
+    subname = Column(String)
+
+
+@reg.mapped
+class Address:
+    __tablename__ = "address"
+
+    id = Column(Integer, primary_key=True)
+    user_id: int = Column(ForeignKey("user.id"))
+    email = Column(String(50))
+
+    user = relationship(User, uselist=False)
+
+
+s1 = SubUser()
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str" # noqa
+x1: str = s1.name
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str") # noqa
+x2: str = s1.name2
+
+x3: str = s1.name3
+
+u1 = User()
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str") # noqa
+x4: str = u1.name3
diff --git a/test/ext/mypy/files/declarative_base_dynamic.py b/test/ext/mypy/files/declarative_base_dynamic.py
new file mode 100644 (file)
index 0000000..eee9b31
--- /dev/null
@@ -0,0 +1,31 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.declarative import declarative_base
+
+# this is actually in orm now
+
+
+Base = declarative_base()
+
+
+class Foo(Base):
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+    other_name: str = Column(String(50))
+
+
+f1 = Foo()
+
+val: int = f1.id
+
+p: str = f1.name
+
+Foo.id.property
+
+# TODO: getitem checker?  this should raise
+Foo.id.property_nonexistent
+
+
+f2 = Foo(name="some name", other_name="some other name")
diff --git a/test/ext/mypy/files/declarative_base_explicit.py b/test/ext/mypy/files/declarative_base_explicit.py
new file mode 100644 (file)
index 0000000..b1b02bf
--- /dev/null
@@ -0,0 +1,30 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+from sqlalchemy.orm.decl_api import DeclarativeMeta
+
+
+class Base(metaclass=DeclarativeMeta):
+    __abstract__ = True
+    registry = registry()
+    metadata = registry.metadata
+
+
+class Foo(Base):
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+    other_name: str = Column(String(50))
+
+
+f1 = Foo()
+
+val: int = f1.id
+
+p: str = f1.name
+
+Foo.id.property
+
+# TODO: getitem checker?  this should raise
+Foo.id.property_nonexistent
diff --git a/test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py b/test/ext/mypy/files/ensure_descriptor_type_fully_inferred.py
new file mode 100644 (file)
index 0000000..1a89041
--- /dev/null
@@ -0,0 +1,20 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+    name = Column(String, nullable=False)
+
+
+u1 = User()
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str")  # noqa E501
+p: str = u1.name
diff --git a/test/ext/mypy/files/ensure_descriptor_type_noninferred.py b/test/ext/mypy/files/ensure_descriptor_type_noninferred.py
new file mode 100644 (file)
index 0000000..b1dabe8
--- /dev/null
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+    name: Mapped[Optional[str]] = Column(String)
+
+
+u1 = User()
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "Optional[int]") # noqa E501
+p: Optional[int] = u1.name
diff --git a/test/ext/mypy/files/ensure_descriptor_type_semiinferred.py b/test/ext/mypy/files/ensure_descriptor_type_semiinferred.py
new file mode 100644 (file)
index 0000000..2154ff0
--- /dev/null
@@ -0,0 +1,26 @@
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+
+    # we will call this "semi-inferred", since the real
+    # type will be Mapped[Optional[str]], but the Optional[str]
+    # which is not inferred, we use that to create it
+    name: Optional[str] = Column(String)
+
+
+u1 = User()
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str")  # noqa E501
+p: str = u1.name
diff --git a/test/ext/mypy/files/imperative_table.py b/test/ext/mypy/files/imperative_table.py
new file mode 100644 (file)
index 0000000..0548a79
--- /dev/null
@@ -0,0 +1,37 @@
+import datetime
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import Mapped
+
+
+Base = declarative_base()
+
+
+class MyMappedClass(Base):
+    __table_ = Table(
+        "some_table",
+        Base.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("data", String(50)),
+        Column("created_at", DateTime),
+    )
+
+    id: Mapped[int]
+    data: Mapped[Optional[str]]
+    created_at: Mapped[datetime.datetime]
+
+
+m1 = MyMappedClass(id=5, data="string", created_at=datetime.datetime.now())
+
+# EXPECTED_MYPY: Argument "created_at" to "MyMappedClass" has incompatible type "int"; expected "datetime" # noqa
+m2 = MyMappedClass(id=5, data="string", created_at=12)
+
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "str") # noqa
+x: str = MyMappedClass().data
diff --git a/test/ext/mypy/files/inspect.py b/test/ext/mypy/files/inspect.py
new file mode 100644 (file)
index 0000000..c67b515
--- /dev/null
@@ -0,0 +1,43 @@
+"""
+test inspect()
+
+however this is not really working
+
+"""
+from sqlalchemy import Column
+from sqlalchemy import create_engine
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import Mapper
+
+Base = declarative_base()
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+
+
+a1 = A(data="d")
+
+e = create_engine("sqlite://")
+
+# TODO: I can't get these to work, pylance and mypy both don't want
+# to accommodate for different types for the first argument
+
+t: bool = inspect(a1).transient
+
+m: Mapper = inspect(A)
+
+inspect(e).get_table_names()
+
+i: Inspector = inspect(e)
+
+
+with e.connect() as conn:
+    inspect(conn).get_table_names()
diff --git a/test/ext/mypy/files/invalid_noninferred_lh_type.py b/test/ext/mypy/files/invalid_noninferred_lh_type.py
new file mode 100644 (file)
index 0000000..5084de7
--- /dev/null
@@ -0,0 +1,15 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+    # EXPECTED: Left hand assignment 'name: "int"' not compatible with ORM mapped expression # noqa E501
+    name: int = Column(String())
diff --git a/test/ext/mypy/files/mapped_attr_assign.py b/test/ext/mypy/files/mapped_attr_assign.py
new file mode 100644 (file)
index 0000000..06bc24d
--- /dev/null
@@ -0,0 +1,58 @@
+"""Test patterns that can be used for assignment of mapped attributes
+after the mapping is complete
+
+
+"""
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import column_property
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id: int = Column(ForeignKey("a.id"))
+
+    # to attach attrs after the fact, declare them with Mapped
+    # on the class...
+    data: Mapped[str]
+
+    a: Mapped[Optional["A"]]
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    bs = relationship(B, uselist=True, back_populates="a")
+
+
+# There's no way to intercept the __setattr__() from the metaclass
+# here, and also when @reg.mapped() is used there is no metaclass.
+# so have them do it the old way
+inspect(B).add_property(
+    "data",
+    column_property(select(A.data).where(A.id == B.a_id).scalar_subquery()),
+)
+inspect(B).add_property("a", relationship(A))
+
+
+# the constructor will pick them up
+a1 = A()
+b1 = B(data="b", a=a1)
+
+# and it's mapped
+B.data.in_(["x", "y"])
+B.a.any()
diff --git a/test/ext/mypy/files/mixin_one.py b/test/ext/mypy/files/mixin_one.py
new file mode 100644 (file)
index 0000000..a471edf
--- /dev/null
@@ -0,0 +1,41 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import registry
+
+
+reg: registry = registry()
+
+# TODO: also reg.as_declarative_base()
+Base = declarative_base()
+
+
+class HasUpdatedAt:
+    updated_at = Column(Integer)
+
+
+@reg.mapped
+class Foo(HasUpdatedAt):
+    __tablename__ = "foo"
+    id: int = Column(Integer(), primary_key=True)
+    name: str = Column(String)
+
+
+class Bar(HasUpdatedAt, Base):
+    __tablename__ = "bar"
+    id = Column(Integer(), primary_key=True)
+    num = Column(Integer)
+
+
+Foo.updated_at.in_([1, 2, 3])
+Bar.updated_at.in_([1, 2, 3])
+
+f1 = Foo(name="name", updated_at=5)
+
+b1 = Bar(num=5, updated_at=6)
+
+
+# test that we detected this as an unmapped mixin
+# EXPECTED_MYPY: Unexpected keyword argument "updated_at" for "HasUpdatedAt"
+HasUpdatedAt(updated_at=5)
diff --git a/test/ext/mypy/files/mixin_two.py b/test/ext/mypy/files/mixin_two.py
new file mode 100644 (file)
index 0000000..c4dc610
--- /dev/null
@@ -0,0 +1,105 @@
+from typing import Callable
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import deferred
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm.decl_api import declared_attr
+from sqlalchemy.orm.interfaces import MapperProperty
+from sqlalchemy.sql.schema import ForeignKey
+
+
+reg: registry = registry()
+
+
+@reg.mapped
+class C:
+    __tablename__ = "c"
+    id = Column(Integer, primary_key=True)
+
+
+def some_other_decorator(fn: Callable[..., None]) -> Callable[..., None]:
+    return fn
+
+
+class HasAMixin:
+    @declared_attr
+    def a(cls) -> Mapped["A"]:
+        return relationship("A", back_populates="bs")
+
+    # EXPECTED: Can't infer type from @declared_attr on function 'a2';
+    @declared_attr
+    def a2(cls):
+        return relationship("A", back_populates="bs")
+
+    @declared_attr
+    def a3(cls) -> relationship["A"]:
+        return relationship("A", back_populates="bs")
+
+    @declared_attr
+    def c1(cls) -> relationship[C]:
+        return relationship(C, back_populates="bs")
+
+    @declared_attr
+    def c2(cls) -> Mapped[C]:
+        return relationship(C, back_populates="bs")
+
+    @declared_attr
+    def data(cls) -> Column[String]:
+        return Column(String)
+
+    @declared_attr
+    def data2(cls) -> MapperProperty[str]:
+        return deferred(Column(String))
+
+    @some_other_decorator
+    def q(cls) -> None:
+        return None
+
+
+@reg.mapped
+class B(HasAMixin):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id: int = Column(ForeignKey("a.id"))
+    c_id: int = Column(ForeignKey("c.id"))
+
+
+@reg.mapped
+class A:
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+
+    @declared_attr
+    def data(cls) -> Column[String]:
+        return Column(String)
+
+    # EXPECTED: Can't infer type from @declared_attr on function 'data2';
+    @declared_attr
+    def data2(cls):
+        return Column(String)
+
+    bs = relationship(B, uselist=True, back_populates="a")
+
+
+a1 = A(id=1, data="d1", data2="d2")
+
+
+b1 = B(a=A(), a2=A(), c1=C(), c2=C(), data="d1", data2="d2")
+
+# descriptor access as Mapped[<type>]
+B.a.any()
+B.a2.any()
+B.c1.any()
+B.c2.any()
+
+# sanity check against another fn that isn't mapped
+# EXPECTED_MYPY: "Callable[..., None]" has no attribute "any"
+B.q.any()
+
+B.data.in_(["a", "b"])
+B.data2.in_(["a", "b"])
diff --git a/test/ext/mypy/files/other_mapper_props.py b/test/ext/mypy/files/other_mapper_props.py
new file mode 100644 (file)
index 0000000..993e144
--- /dev/null
@@ -0,0 +1,56 @@
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import column_property
+from sqlalchemy.orm import deferred
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import synonym
+from sqlalchemy.sql.functions import func
+from sqlalchemy.sql.sqltypes import Text
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+    name = Column(String)
+
+    # this gets inferred
+    big_col = deferred(Column(Text))
+
+    # this gets inferred
+    explicit_col = column_property(Column(Integer))
+
+    # EXPECTED: Can't infer type from ORM mapped expression assigned to attribute 'lower_name'; # noqa
+    lower_name = column_property(func.lower(name))
+
+    # EXPECTED: Can't infer type from ORM mapped expression assigned to attribute 'syn_name'; # noqa
+    syn_name = synonym("name")
+
+    # this uses our type
+    lower_name_exp: str = column_property(func.lower(name))
+
+    # this uses our type
+    syn_name_exp: Optional[str] = synonym("name")
+
+
+s = Session()
+
+u1: User = s.get(User, 5)
+
+q1: Optional[str] = u1.big_col
+
+q2: Optional[int] = u1.explicit_col
+
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "str", variable has type "int") # noqa
+x: int = u1.lower_name_exp
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "Optional[str]", variable has type "int") # noqa
+y: int = u1.syn_name_exp
diff --git a/test/ext/mypy/files/plugin_doesnt_break_one.py b/test/ext/mypy/files/plugin_doesnt_break_one.py
new file mode 100644 (file)
index 0000000..19cb2bf
--- /dev/null
@@ -0,0 +1,20 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class Foo:
+    pass
+    id: int = Column(Integer())
+    name: str = Column(String)
+
+
+f1 = Foo()
+
+
+# EXPECTED_MYPY: Name 'u1' is not defined
+p: str = u1.name  # noqa
diff --git a/test/ext/mypy/files/relationship_direct_cls.py b/test/ext/mypy/files/relationship_direct_cls.py
new file mode 100644 (file)
index 0000000..1c4efde
--- /dev/null
@@ -0,0 +1,36 @@
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id: int = Column(ForeignKey("a.id"))
+    data = Column(String)
+
+    a: Optional["A"] = relationship("A", back_populates="bs")
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    bs = relationship(B, uselist=True, back_populates="a")
+
+
+a1 = A(bs=[B(data="b"), B(data="b")])
+
+x: List[B] = a1.bs
+
+
+b1 = B(a=A())
diff --git a/test/ext/mypy/files/relationship_err1.py b/test/ext/mypy/files/relationship_err1.py
new file mode 100644 (file)
index 0000000..46e7067
--- /dev/null
@@ -0,0 +1,30 @@
+from typing import List
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+
+    # EXPECTED: Expected Python collection type for collection_class parameter # noqa
+    as_: List["A"] = relationship("A", collection_class=None)
+
+    # EXPECTED: Can't infer type from ORM mapped expression assigned to attribute 'another_as_'; # noqa
+    another_as_ = relationship("A", uselist=True)
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    b_id: int = Column(ForeignKey("b.id"))
+
+    # EXPECTED: Sending uselist=False and collection_class at the same time does not make sense # noqa
+    b: B = relationship(B, uselist=False, collection_class=set)
diff --git a/test/ext/mypy/files/relationship_err2.py b/test/ext/mypy/files/relationship_err2.py
new file mode 100644 (file)
index 0000000..4057bae
--- /dev/null
@@ -0,0 +1,32 @@
+from typing import Set
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id: int = Column(ForeignKey("a.id"))
+    data = Column(String)
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    bs = relationship(B, uselist=True)
+
+
+# EXPECTED_MYPY: List item 1 has incompatible type "A"; expected "B"
+a1 = A(bs=[B(data="b"), A()])
+
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "List[B]", variable has type "Set[B]") # noqa
+x: Set[B] = a1.bs
diff --git a/test/ext/mypy/files/relationship_err3.py b/test/ext/mypy/files/relationship_err3.py
new file mode 100644 (file)
index 0000000..aa76ae1
--- /dev/null
@@ -0,0 +1,35 @@
+from typing import Optional
+from typing import Set
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class B(Base):
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    a_id: int = Column(ForeignKey("a.id"))
+    data = Column(String)
+    a: Optional["A"] = relationship("A", back_populates="bs")
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id = Column(Integer, primary_key=True)
+    data = Column(String)
+    # EXPECTED: Left hand assignment 'bs: "Set[B]"' not compatible with ORM mapped expression of type "Mapped[List[B]]" # noqa
+    bs: Set[B] = relationship(B, uselist=True, back_populates="a")
+
+    # EXPECTED: Left hand assignment 'another_bs: "Set[B]"' not compatible with ORM mapped expression of type "Mapped[B]" # noqa
+    another_bs: Set[B] = relationship(B, viewonly=True)
+
+
+# EXPECTED_MYPY: Argument "a" to "B" has incompatible type "str"; expected "Optional[A]" # noqa
+b1 = B(a="not an a")
diff --git a/test/ext/mypy/files/typeless_fk_col_cant_infer.py b/test/ext/mypy/files/typeless_fk_col_cant_infer.py
new file mode 100644 (file)
index 0000000..beb4a7a
--- /dev/null
@@ -0,0 +1,25 @@
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import registry
+
+reg: registry = registry()
+
+
+@reg.mapped
+class User:
+    __tablename__ = "user"
+
+    id = Column(Integer(), primary_key=True)
+    name = Column(String)
+
+
+@reg.mapped
+class Address:
+    __tablename__ = "address"
+
+    id = Column(Integer, primary_key=True)
+    # EXPECTED: Can't infer type from ORM mapped expression assigned to attribute 'user_id';  # noqa E501
+    user_id = Column(ForeignKey("user.id"))
+    email_address = Column(String)
diff --git a/test/ext/mypy/files/typing_err1.py b/test/ext/mypy/files/typing_err1.py
new file mode 100644 (file)
index 0000000..f262cd5
--- /dev/null
@@ -0,0 +1,31 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy.orm import registry
+from sqlalchemy.types import TypeEngine
+
+
+# EXPECTED_MYPY: Missing type parameters for generic type "TypeEngine"
+class MyCustomType(TypeEngine):
+    pass
+
+
+# correct way
+class MyOtherCustomType(TypeEngine[str]):
+    pass
+
+
+reg: registry = registry()
+
+
+@reg.mapped
+class Foo:
+    id: int = Column(Integer())
+
+    name = Column(MyCustomType())
+    other_name: str = Column(MyCustomType())
+
+    name2 = Column(MyOtherCustomType())
+    other_name2: str = Column(MyOtherCustomType())
+
+
+Foo(name="x", other_name="x", name2="x", other_name2="x")
diff --git a/test/ext/mypy/files/typing_err2.py b/test/ext/mypy/files/typing_err2.py
new file mode 100644 (file)
index 0000000..adc50f9
--- /dev/null
@@ -0,0 +1,37 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+
+reg: registry = registry()
+
+
+@reg.mapped
+class Foo:
+    id: int = Column(Integer())
+
+    # EXPECTED: Can't infer type from @declared_attr on function 'name'; # noqa
+    @declared_attr
+    # EXPECTED: Column type should be a TypeEngine subclass not 'builtins.str'
+    def name(cls) -> Column[str]:
+        return Column(String)
+
+    # EXPECTED: Left hand assignment 'other_name: "Column[String]"' not compatible with ORM mapped expression of type "Mapped[str]" # noqa
+    other_name: Column[String] = Column(String)
+
+    # EXPECTED: Can't infer type from @declared_attr on function 'third_name';
+    @declared_attr
+    # EXPECTED_MYPY: Missing type parameters for generic type "Column"
+    def third_name(cls) -> Column:
+        return Column(String)
+
+    # EXPECTED: Can't infer type from @declared_attr on function 'some_relationship' # noqa
+    @declared_attr
+    # EXPECTED_MYPY: Missing type parameters for generic type "relationship"
+    def some_relationship(cls) -> relationship:
+        return relationship("Bar")
+
+
+Foo(name="x")
diff --git a/test/ext/mypy/files/typing_err3.py b/test/ext/mypy/files/typing_err3.py
new file mode 100644 (file)
index 0000000..5383f89
--- /dev/null
@@ -0,0 +1,54 @@
+"""Test that the right-hand expressions we normally "replace" are actually
+type checked.
+
+"""
+from typing import List
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm.decl_api import declared_attr
+
+
+Base = declarative_base()
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id = Column(Integer, primary_key=True)
+
+    # EXPECTED_MYPY: Unexpected keyword argument "wrong_arg" for "RelationshipProperty" # noqa
+    addresses: Mapped[List["Address"]] = relationship(
+        "Address", wrong_arg="imwrong"
+    )
+
+
+class SubUser(User):
+    __tablename__ = "subuser"
+
+    id: int = Column(Integer, ForeignKey("user.id"), primary_key=True)
+
+
+class Address(Base):
+    __tablename__ = "address"
+
+    id: int = Column(Integer, primary_key=True)
+
+    user_id: int = Column(ForeignKey("user.id"))
+
+    @declared_attr
+    def email_address(cls) -> Column[String]:
+        # EXPECTED_MYPY: No overload variant of "Column" matches argument type "bool" # noqa
+        return Column(True)
+
+    @declared_attr
+    # EXPECTED_MYPY: Invalid type comment or annotation
+    def thisisweird(cls) -> Column(String):
+        # with the bad annotation mypy seems to not go into the
+        # function body
+        return Column(False)
diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py
new file mode 100644 (file)
index 0000000..bf82aaa
--- /dev/null
@@ -0,0 +1,115 @@
+import os
+import re
+import tempfile
+
+from sqlalchemy import testing
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+
+
+class MypyPluginTest(fixtures.TestBase):
+    __requires__ = ("sqlalchemy2_stubs",)
+
+    @testing.fixture(scope="class")
+    def cachedir(self):
+        with tempfile.TemporaryDirectory() as cachedir:
+            with open(
+                os.path.join(cachedir, "sqla_mypy_config.cfg"), "w"
+            ) as config_file:
+                config_file.write(
+                    """
+                    [mypy]\n
+                    plugins = sqlalchemy.ext.mypy.plugin\n
+                    """
+                )
+            with open(
+                os.path.join(cachedir, "plain_mypy_config.cfg"), "w"
+            ) as config_file:
+                config_file.write(
+                    """
+                    [mypy]\n
+                    """
+                )
+            yield cachedir
+
+    @testing.fixture()
+    def mypy_runner(self, cachedir):
+        from mypy import api
+
+        def run(filename, use_plugin=True):
+            path = os.path.join(os.path.dirname(__file__), "files", filename)
+
+            args = [
+                "--strict",
+                "--raise-exceptions",
+                "--cache-dir",
+                cachedir,
+                "--config-file",
+                os.path.join(
+                    cachedir,
+                    "sqla_mypy_config.cfg"
+                    if use_plugin
+                    else "plain_mypy_config.cfg",
+                ),
+            ]
+
+            args.append(path)
+
+            return api.run(args)
+
+        return run
+
+    def _file_combinations():
+        path = os.path.join(os.path.dirname(__file__), "files")
+        return [f for f in os.listdir(path) if f.endswith(".py")]
+
+    @testing.combinations(
+        *[(filename,) for filename in _file_combinations()],
+        argnames="filename"
+    )
+    def test_mypy(self, mypy_runner, filename):
+        path = os.path.join(os.path.dirname(__file__), "files", filename)
+
+        use_plugin = True
+
+        expected_errors = []
+        with open(path) as file_:
+            for num, line in enumerate(file_, 1):
+                if line.startswith("# NOPLUGINS"):
+                    use_plugin = False
+                    continue
+
+                m = re.match(r"\s*# EXPECTED(_MYPY)?: (.+)", line)
+                if m:
+                    is_mypy = bool(m.group(1))
+                    expected_msg = m.group(2)
+                    expected_msg = re.sub(r"# noqa ?.*", "", m.group(2))
+                    expected_errors.append(
+                        (num, is_mypy, expected_msg.strip())
+                    )
+
+        result = mypy_runner(filename, use_plugin=use_plugin)
+
+        if expected_errors:
+            eq_(result[2], 1)
+
+            print(result[0])
+
+            errors = []
+            for e in result[0].split("\n"):
+                if re.match(r".+\.py:\d+: error: .*", e):
+                    errors.append(e)
+
+            for num, is_mypy, msg in expected_errors:
+                prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
+                for idx, errmsg in enumerate(errors):
+                    if f"{filename}:{num + 1}: error: {prefix}{msg}" in errmsg:
+                        break
+                else:
+                    continue
+                del errors[idx]
+
+            assert not errors, "errors remain: %s" % "\n".join(errors)
+
+        else:
+            eq_(result[2], 0, msg=result[0])
diff --git a/tox.ini b/tox.ini
index d57b95431b9d8ff21f9dd32aa7c9b2df611339a6..96604887756c5fd39848ad0deaf11cf792a9deb7 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -113,6 +113,19 @@ commands=
   {env:BASECOMMAND} {env:WORKERS} {env:SQLITE:} {env:POSTGRESQL:} {env:EXTRA_PG_DRIVERS:} {env:MYSQL:} {env:EXTRA_MYSQL_DRIVERS:} {env:ORACLE:} {env:MSSQL:} {env:BACKENDONLY:} {env:IDENTS:} {env:MEMUSAGE:} {env:COVERAGE:} {posargs}
   oracle,mssql,sqlite_file: python reap_dbs.py db_idents.txt
 
+
+[testenv:mypy]
+deps=
+     pytest>=6.2; python_version >= '3'
+     pytest-xdist
+     greenlet != 0.4.17
+     mock; python_version < '3.3'
+     importlib_metadata; python_version < '3.8'
+     mypy
+     git+https://github.com/sqlalchemy/sqlalchemy2-stubs
+commands =
+    pytest test/ext/mypy/test_mypy_plugin_py3k.py {posargs}
+
 # thanks to https://julien.danjou.info/the-best-flake8-extensions/
 [testenv:pep8]
 basepython = python3