From 606096ae01c71298da4d3fda3f62c730d9985105 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 26 Mar 2021 19:45:29 -0400 Subject: [PATCH] Adjust for mypy incremental behaviors Applied a series of refactorings and fixes to accommodate for Mypy "incremental" mode across multiple files, which previously was not taken into account. In this mode the Mypy plugin has to accommodate Python datatypes expressed in other files coming in with less information than they have on a direct run. Additionally, a new decorator :func:`_orm.declarative_mixin` is added, which is necessary for the Mypy plugin to be able to definifitely identify a Declarative mixin class that is otherwise not used inside a particular Python file. discussion: With incremental / deserialized mypy runs, it appears that when we look at a base class that comes from another file, cls.info is set to a special undefined node that matches CLASSDEF_NO_INFO, and we otherwise can't touch it without crashing. Additionally, sometimes cls.defs.body is present but empty. However, it appears that both of these cases can be sidestepped, first by doing a lookup() for the type name where we get a SymbolTableNode that then has the TypeInfo we wanted when we tried touching cls.info, and then however we got the TypeInfo, if cls.defs.body is empty we can just look in the names to get at the symbols for that class; we just can't access AssignmentStmts, but that's fine because we just need the information for classes we aren't actually type checking. This work also revealed there's no easy way to detect a mixin class so we just create a new decorator to mark that. will make code look better in any case. Fixes: #6147 Change-Id: Ia8fac8acfeec931d8f280491cffc5c6cb4a1204e --- doc/build/changelog/unreleased_14/6147.rst | 19 + doc/build/orm/declarative_mixins.rst | 70 +- doc/build/orm/extensions/mypy.rst | 33 +- doc/build/orm/mapping_api.rst | 2 + lib/sqlalchemy/ext/mypy/apply.py | 215 +++++ lib/sqlalchemy/ext/mypy/decl_class.py | 771 +++--------------- lib/sqlalchemy/ext/mypy/infer.py | 398 +++++++++ lib/sqlalchemy/ext/mypy/names.py | 8 + lib/sqlalchemy/ext/mypy/plugin.py | 31 +- lib/sqlalchemy/ext/mypy/util.py | 62 +- lib/sqlalchemy/orm/__init__.py | 1 + lib/sqlalchemy/orm/decl_api.py | 42 + test/ext/mypy/files/mixin_three.py | 33 + .../mypy/incremental/ticket_6147/__init__.py | 0 test/ext/mypy/incremental/ticket_6147/base.py | 3 + test/ext/mypy/incremental/ticket_6147/one.py | 13 + .../incremental/ticket_6147/patch1.testpatch | 19 + .../incremental/ticket_6147/patch2.testpatch | 38 + test/ext/mypy/test_mypy_plugin_py3k.py | 17 +- test/orm/declarative/test_mixin.py | 34 + 20 files changed, 1122 insertions(+), 687 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6147.rst create mode 100644 lib/sqlalchemy/ext/mypy/apply.py create mode 100644 lib/sqlalchemy/ext/mypy/infer.py create mode 100644 test/ext/mypy/files/mixin_three.py create mode 100644 test/ext/mypy/incremental/ticket_6147/__init__.py create mode 100644 test/ext/mypy/incremental/ticket_6147/base.py create mode 100644 test/ext/mypy/incremental/ticket_6147/one.py create mode 100644 test/ext/mypy/incremental/ticket_6147/patch1.testpatch create mode 100644 test/ext/mypy/incremental/ticket_6147/patch2.testpatch diff --git a/doc/build/changelog/unreleased_14/6147.rst b/doc/build/changelog/unreleased_14/6147.rst new file mode 100644 index 0000000000..325ae0edc7 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6147.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: bug, mypy + :tickets: 6147 + + Applied a series of refactorings and fixes to accommodate for Mypy + "incremental" mode across multiple files, which previously was not taken + into account. In this mode the Mypy plugin has to accommodate Python + datatypes expressed in other files coming in with less information than + they have on a direct run. + + Additionally, a new decorator :func:`_orm.declarative_mixin` is added, + which is necessary for the Mypy plugin to be able to definifitely identify + a Declarative mixin class that is otherwise not used inside a particular + Python file. + + .. seealso:: + + :ref:`mypy_declarative_mixins` + diff --git a/doc/build/orm/declarative_mixins.rst b/doc/build/orm/declarative_mixins.rst index 309c322602..9bb4c782e4 100644 --- a/doc/build/orm/declarative_mixins.rst +++ b/doc/build/orm/declarative_mixins.rst @@ -16,9 +16,11 @@ or :func:`_orm.declarative_base` functions. An example of some commonly mixed-in idioms is below:: + from sqlalchemy.orm import declarative_mixin from sqlalchemy.orm import declared_attr - class MyMixin(object): + @declarative_mixin + class MyMixin: @declared_attr def __tablename__(cls): @@ -37,6 +39,14 @@ as the primary key, a ``__tablename__`` attribute that derives from the name of the class itself, as well as ``__table_args__`` and ``__mapper_args__`` defined by the ``MyMixin`` mixin class. +.. tip:: + + The use of the :func:`_orm.declarative_mixin` class decorator marks a + particular class as providing the service of providing SQLAlchemy declarative + assignments as a mixin for other classes. This decorator is currently only + necessary to provide a hint to the :ref:`Mypy plugin ` that + this class should be handled as part of declarative mappings. + There's no fixed convention over whether ``MyMixin`` precedes ``Base`` or not. Normal Python method resolution rules apply, and the above example would work just as well with:: @@ -61,7 +71,7 @@ using the ``cls`` argument of the :func:`_orm.declarative_base` function:: from sqlalchemy.orm import declared_attr - class Base(object): + class Base: @declared_attr def __tablename__(cls): return cls.__name__.lower() @@ -87,7 +97,8 @@ Mixing in Columns The most basic way to specify a column on a mixin is by simple declaration:: - class TimestampMixin(object): + @declarative_mixin + class TimestampMixin: created_at = Column(DateTime, default=func.now()) class MyModel(TimestampMixin, Base): @@ -124,7 +135,8 @@ patterns common to many classes can be defined as callables:: from sqlalchemy.orm import declared_attr - class ReferenceAddressMixin(object): + @declarative_mixin + class ReferenceAddressMixin: @declared_attr def address_id(cls): return Column(Integer, ForeignKey('address.id')) @@ -143,6 +155,7 @@ referenced by ``__mapper_args__`` to a limited degree, currently by ``polymorphic_on`` and ``version_id_col``; the declarative extension will resolve them at class construction time:: + @declarative_mixin class MyMixin: @declared_attr def type_(cls): @@ -167,7 +180,8 @@ contents. Below is an example which combines a foreign key column and a relationship so that two classes ``Foo`` and ``Bar`` can both be configured to reference a common target class via many-to-one:: - class RefTargetMixin(object): + @declarative_mixin + class RefTargetMixin: @declared_attr def target_id(cls): return Column('target_id', ForeignKey('target.id')) @@ -206,7 +220,8 @@ Declarative will be using as it calls the methods on its own, thus using The canonical example is the primaryjoin condition that depends upon another mixed-in column:: - class RefTargetMixin(object): + @declarative_mixin + class RefTargetMixin: @declared_attr def target_id(cls): return Column('target_id', ForeignKey('target.id')) @@ -228,7 +243,8 @@ actually going to map to our table. The condition above is resolved using a lambda:: - class RefTargetMixin(object): + @declarative_mixin + class RefTargetMixin: @declared_attr def target_id(cls): return Column('target_id', ForeignKey('target.id')) @@ -241,7 +257,8 @@ The condition above is resolved using a lambda:: or alternatively, the string form (which ultimately generates a lambda):: - class RefTargetMixin(object): + @declarative_mixin + class RefTargetMixin: @declared_attr def target_id(cls): return Column('target_id', ForeignKey('target.id')) @@ -266,7 +283,8 @@ etc. ultimately involve references to columns, and therefore, when used with declarative mixins, have the :class:`_orm.declared_attr` requirement so that no reliance on copying is needed:: - class SomethingMixin(object): + @declarative_mixin + class SomethingMixin: @declared_attr def dprop(cls): @@ -279,7 +297,8 @@ The :func:`.column_property` or other construct may refer to other columns from the mixin. These are copied ahead of time before the :class:`_orm.declared_attr` is invoked:: - class SomethingMixin(object): + @declarative_mixin + class SomethingMixin: x = Column(Integer) y = Column(Integer) @@ -306,13 +325,16 @@ target a different type of child object. Below is an string values to an implementing class:: from sqlalchemy import Column, Integer, ForeignKey, String - from sqlalchemy.orm import relationship from sqlalchemy.ext.associationproxy import association_proxy - from sqlalchemy.orm import declarative_base, declared_attr + from sqlalchemy.orm import declarative_base + from sqlalchemy.orm import declarative_mixin + from sqlalchemy.orm import declared_attr + from sqlalchemy.orm import relationship Base = declarative_base() - class HasStringCollection(object): + @declarative_mixin + class HasStringCollection: @declared_attr def _strings(cls): class StringAttribute(Base): @@ -389,8 +411,10 @@ correct answer for each. For example, to create a mixin that gives every class a simple table name based on class name:: + from sqlalchemy.orm import declarative_mixin from sqlalchemy.orm import declared_attr + @declarative_mixin class Tablename: @declared_attr def __tablename__(cls): @@ -411,10 +435,12 @@ Alternatively, we can modify our ``__tablename__`` function to return the effect of those subclasses being mapped with single table inheritance against the parent:: + from sqlalchemy.orm import declarative_mixin from sqlalchemy.orm import declared_attr from sqlalchemy.orm import has_inherited_table - class Tablename(object): + @declarative_mixin + class Tablename: @declared_attr def __tablename__(cls): if has_inherited_table(cls): @@ -443,7 +469,8 @@ invoked for the **base class only** in the hierarchy. Below, only the called ``id``; the mapping will fail on ``Engineer``, which is not given a primary key:: - class HasId(object): + @declarative_mixin + class HasId: @declared_attr def id(cls): return Column('id', Integer, primary_key=True) @@ -466,7 +493,8 @@ foreign key. We can achieve this as a mixin by using the function should be invoked **for each class in the hierarchy**, in *almost* (see warning below) the same way as it does for ``__tablename__``:: - class HasIdMixin(object): + @declarative_mixin + class HasIdMixin: @declared_attr.cascading def id(cls): if has_inherited_table(cls): @@ -509,12 +537,15 @@ define on the class itself. The here to create user-defined collation routines that pull from multiple collections:: + from sqlalchemy.orm import declarative_mixin from sqlalchemy.orm import declared_attr - class MySQLSettings(object): + @declarative_mixin + class MySQLSettings: __table_args__ = {'mysql_engine':'InnoDB'} - class MyOtherMixin(object): + @declarative_mixin + class MyOtherMixin: __table_args__ = {'info':'foo'} class MyModel(MySQLSettings, MyOtherMixin, Base): @@ -536,7 +567,8 @@ To define a named, potentially multicolumn :class:`.Index` that applies to all tables derived from a mixin, use the "inline" form of :class:`.Index` and establish it as part of ``__table_args__``:: - class MyMixin(object): + @declarative_mixin + class MyMixin: a = Column(Integer) b = Column(Integer) diff --git a/doc/build/orm/extensions/mypy.rst b/doc/build/orm/extensions/mypy.rst index e8d85c1bf2..fd3beed0b2 100644 --- a/doc/build/orm/extensions/mypy.rst +++ b/doc/build/orm/extensions/mypy.rst @@ -415,23 +415,32 @@ applied explicitly:: user: Mapped[User] = relationship(User, back_populates="addresses") -Using @declared_attr -^^^^^^^^^^^^^^^^^^^^ - -The :class:`_orm.declared_attr` class allows Declarative mapped attributes -to be declared in class level functions, and is particularly useful when -using `declarative mixins `_. For these functions, -the return type of the function should be annotated using either the -``Mapped[]`` construct or by indicating the exact kind of object returned -by the function:: - - from sqlalchemy.orm.decl_api import declared_attr - +.. _mypy_declarative_mixins: + +Using @declared_attr and Declarative Mixins +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`_orm.declared_attr` class allows Declarative mapped attributes to +be declared in class level functions, and is particularly useful when using +`declarative mixins `_. For these functions, the return +type of the function should be annotated using either the ``Mapped[]`` +construct or by indicating the exact kind of object returned by the function. +Additionally, "mixin" classes that are not otherwise mapped (i.e. don't extend +from a :func:`_orm.declarative_base` class nor are they mapped with a method +such as :meth:`_orm.registry.mapped`) should be decorated with the +:func:`_orm.declarative_mixin` decorator, which provides a hint to the Mypy +plugin that a particular class intends to serve as a declarative mixin:: + + from sqlalchemy.orm import declared_attr + from sqlalchemy.orm import declarative_mixin + + @declarative_mixin class HasUpdatedAt: @declared_attr def updated_at(cls) -> Column[DateTime]: # uses Column return Column(DateTime) + @declarative_mixin class HasCompany: @declared_attr diff --git a/doc/build/orm/mapping_api.rst b/doc/build/orm/mapping_api.rst index 6aa08114da..5d0b6c0d02 100644 --- a/doc/build/orm/mapping_api.rst +++ b/doc/build/orm/mapping_api.rst @@ -9,6 +9,8 @@ Class Mapping API .. autofunction:: declarative_base +.. autofunction:: declarative_mixin + .. autofunction:: as_declarative .. autoclass:: declared_attr diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py new file mode 100644 index 0000000000..6442cbc220 --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -0,0 +1,215 @@ +# ext/mypy/apply.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from typing import Optional +from typing import Union + +from mypy import nodes +from mypy.nodes import ARG_NAMED_OPT +from mypy.nodes import Argument +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import MDEF +from mypy.nodes import NameExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType +from mypy.types import Instance +from mypy.types import NoneTyp +from mypy.types import TypeOfAny +from mypy.types import UnionType + +from . import util + + +def _apply_mypy_mapped_attr( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + item: Union[NameExpr, StrExpr], + cls_metadata: util.DeclClassApplied, +): + if isinstance(item, NameExpr): + name = item.name + elif isinstance(item, StrExpr): + name = item.value + else: + return + + for stmt in cls.defs.body: + if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name: + break + else: + util.fail(api, "Can't find mapped attribute {}".format(name), cls) + return + + if stmt.type is None: + util.fail( + api, + "Statement linked from _mypy_mapped_attrs has no " + "typing information", + stmt, + ) + return + + left_hand_explicit_type = stmt.type + + cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) + + _apply_type_to_mapped_statement( + api, stmt, stmt.lvalues[0], left_hand_explicit_type, None + ) + + +def _re_apply_declarative_assignments( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + cls_metadata: util.DeclClassApplied, +): + """For multiple class passes, re-apply our left-hand side types as mypy + seems to reset them in place. + + """ + mapped_attr_lookup = { + name: typ for name, typ in cls_metadata.mapped_attr_names + } + + descriptor = api.lookup("__sa_Mapped", cls) + for stmt in cls.defs.body: + # for a re-apply, all of our statements are AssignmentStmt; + # @declared_attr calls will have been converted and this + # currently seems to be preserved by mypy (but who knows if this + # will change). + if ( + isinstance(stmt, AssignmentStmt) + and stmt.lvalues[0].name in mapped_attr_lookup + ): + typ = mapped_attr_lookup[stmt.lvalues[0].name] + left_node = stmt.lvalues[0].node + + inst = Instance(descriptor.node, [typ]) + left_node.type = inst + + +def _apply_type_to_mapped_statement( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + lvalue: NameExpr, + left_hand_explicit_type: Optional[Union[Instance, UnionType]], + python_type_for_type: Union[Instance, UnionType], +) -> None: + """Apply the Mapped[] annotation and right hand object to a + declarative assignment statement. + + This converts a Python declarative class statement such as:: + + class User(Base): + # ... + + attrname = Column(Integer) + + To one that describes the final Python behavior to Mypy:: + + class User(Base): + # ... + + attrname : Mapped[Optional[int]] = + + """ + descriptor = api.lookup("__sa_Mapped", stmt) + left_node = lvalue.node + + inst = Instance(descriptor.node, [python_type_for_type]) + + if left_hand_explicit_type is not None: + left_node.type = Instance(descriptor.node, [left_hand_explicit_type]) + else: + lvalue.is_inferred_def = False + left_node.type = inst + + # so to have it skip the right side totally, we can do this: + # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # however, if we instead manufacture a new node that uses the old + # one, then we can still get type checking for the call itself, + # e.g. the Column, relationship() call, etc. + + # rewrite the node as: + # : Mapped[] = + # _sa_Mapped._empty_constructor() + # the original right-hand side is maintained so it gets type checked + # internally + column_descriptor = nodes.NameExpr("__sa_Mapped") + column_descriptor.fullname = "sqlalchemy.orm.Mapped" + mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") + orig_call_expr = stmt.rvalue + stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"]) + + +def _add_additional_orm_attributes( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + cls_metadata: util.DeclClassApplied, +) -> None: + """Apply __init__, __table__ and other attributes to the mapped class.""" + + info = util._info_for_cls(cls, api) + if "__init__" not in info.names and cls_metadata.is_mapped: + mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names} + + for mapped_base in cls_metadata.mapped_mro: + base_cls_metadata = util.DeclClassApplied.deserialize( + mapped_base.type.metadata["_sa_decl_class_applied"], api + ) + for n, t in base_cls_metadata.mapped_attr_names: + mapped_attr_names.setdefault(n, t) + + arguments = [] + for name, typ in mapped_attr_names.items(): + if typ is None: + typ = AnyType(TypeOfAny.special_form) + arguments.append( + Argument( + variable=Var(name, typ), + type_annotation=typ, + initializer=TempNode(typ), + kind=ARG_NAMED_OPT, + ) + ) + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) + + if "__table__" not in info.names and cls_metadata.has_table: + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.sql.schema.Table", "__table__" + ) + if cls_metadata.is_mapped: + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" + ) + + +def _apply_placeholder_attr_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + qualified_name: str, + attrname: str, +): + sym = api.lookup_fully_qualified_or_none(qualified_name) + if sym: + assert isinstance(sym.node, TypeInfo) + type_ = Instance(sym.node, []) + else: + type_ = AnyType(TypeOfAny.special_form) + var = Var(attrname) + var.info = cls.info + var.type = type_ + cls.info.names[attrname] = SymbolTableNode(MDEF, var) diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 46f3cc30e3..a0e272f713 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -6,23 +6,14 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from typing import Optional -from typing import Sequence -from typing import Tuple from typing import Type -from typing import Union from mypy import nodes -from mypy import types -from mypy.messages import format_type -from mypy.nodes import ARG_NAMED_OPT -from mypy.nodes import Argument from mypy.nodes import AssignmentStmt from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Decorator -from mypy.nodes import JsonDict from mypy.nodes import ListExpr -from mypy.nodes import MDEF from mypy.nodes import NameExpr from mypy.nodes import PlaceholderNode from mypy.nodes import RefExpr @@ -32,73 +23,30 @@ from mypy.nodes import TempNode from mypy.nodes import TypeInfo from mypy.nodes import Var from mypy.plugin import SemanticAnalyzerPluginInterface -from mypy.plugins.common import add_method_to_class -from mypy.plugins.common import deserialize_and_fixup_type -from mypy.subtypes import is_subtype from mypy.types import AnyType from mypy.types import Instance -from mypy.types import NoneTyp from mypy.types import NoneType from mypy.types import TypeOfAny from mypy.types import UnboundType from mypy.types import UnionType +from . import apply +from . import infer from . import names from . import util -class DeclClassApplied: - def __init__( - self, - is_mapped: bool, - has_table: bool, - mapped_attr_names: Sequence[Tuple[str, Type]], - mapped_mro: Sequence[Type], - ): - self.is_mapped = is_mapped - self.has_table = has_table - self.mapped_attr_names = mapped_attr_names - self.mapped_mro = mapped_mro - - def serialize(self) -> JsonDict: - return { - "is_mapped": self.is_mapped, - "has_table": self.has_table, - "mapped_attr_names": [ - (name, type_.serialize()) - for name, type_ in self.mapped_attr_names - ], - "mapped_mro": [type_.serialize() for type_ in self.mapped_mro], - } - - @classmethod - def deserialize( - cls, data: JsonDict, api: SemanticAnalyzerPluginInterface - ) -> "DeclClassApplied": - - return DeclClassApplied( - is_mapped=data["is_mapped"], - has_table=data["has_table"], - mapped_attr_names=[ - (name, deserialize_and_fixup_type(type_, api)) - for name, type_ in data["mapped_attr_names"] - ], - mapped_mro=[ - deserialize_and_fixup_type(type_, api) - for type_ in data["mapped_mro"] - ], - ) - - def _scan_declarative_assignments_and_apply_types( cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False -) -> Optional[DeclClassApplied]: +) -> Optional[util.DeclClassApplied]: + + info = util._info_for_cls(cls, api) if cls.fullname.startswith("builtins"): return None - elif "_sa_decl_class_applied" in cls.info.metadata: - cls_metadata = DeclClassApplied.deserialize( - cls.info.metadata["_sa_decl_class_applied"], api + elif "_sa_decl_class_applied" in info.metadata: + cls_metadata = util.DeclClassApplied.deserialize( + info.metadata["_sa_decl_class_applied"], api ) # ensure that a class that's mapped is always picked up by @@ -112,30 +60,117 @@ def _scan_declarative_assignments_and_apply_types( # removing our ability to re-scan. but we have the types # here, so lets re-apply them. - _re_apply_declarative_assignments(cls, api, cls_metadata) + apply._re_apply_declarative_assignments(cls, api, cls_metadata) return cls_metadata - cls_metadata = DeclClassApplied(not is_mixin_scan, False, [], []) + cls_metadata = util.DeclClassApplied(not is_mixin_scan, False, [], []) + + if not cls.defs.body: + # when we get a mixin class from another file, the body is + # empty (!) but the names are in the symbol table. so use that. - for stmt in util._flatten_typechecking(cls.defs.body): - if isinstance(stmt, AssignmentStmt): - _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata) - elif isinstance(stmt, Decorator): - _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata) + for sym_name, sym in info.names.items(): + _scan_symbol_table_entry(cls, api, sym_name, sym, cls_metadata) + else: + for stmt in util._flatten_typechecking(cls.defs.body): + if isinstance(stmt, AssignmentStmt): + _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata) + elif isinstance(stmt, Decorator): + _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata) _scan_for_mapped_bases(cls, api, cls_metadata) - _add_additional_orm_attributes(cls, api, cls_metadata) - cls.info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize() + if not is_mixin_scan: + apply._add_additional_orm_attributes(cls, api, cls_metadata) + + info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize() return cls_metadata +def _scan_symbol_table_entry( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + name: str, + value: SymbolTableNode, + cls_metadata: util.DeclClassApplied, +): + """Extract mapping information from a SymbolTableNode that's in the + type.names dictionary. + + """ + if not isinstance(value.type, Instance): + return + + left_hand_explicit_type = None + type_id = names._type_id_for_named_node(value.type.type) + # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) + + err = False + + # TODO: this is nearly the same logic as that of + # _scan_declarative_decorator_stmt, likely can be merged + if type_id in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + }: + if value.type.args: + left_hand_explicit_type = value.type.args[0] + else: + err = True + elif type_id is names.COLUMN: + if not value.type.args: + err = True + else: + typeengine_arg = value.type.args[0] + if isinstance(typeengine_arg, Instance): + typeengine_arg = typeengine_arg.type + + if isinstance(typeengine_arg, (UnboundType, TypeInfo)): + sym = api.lookup(typeengine_arg.name, typeengine_arg) + if sym is not None: + if names._mro_has_id(sym.node.mro, names.TYPEENGINE): + + left_hand_explicit_type = UnionType( + [ + infer._extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + value.type, + ) + + if err: + msg = ( + "Can't infer type from attribute {} on class {}. " + "please specify a return type from this function that is " + "one of: Mapped[], relationship[], " + "Column[], MapperProperty[]" + ) + util.fail(api, msg.format(name, cls.name)) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + if left_hand_explicit_type is not None: + cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) + + def _scan_declarative_decorator_stmt( cls: ClassDef, api: SemanticAnalyzerPluginInterface, stmt: Decorator, - cls_metadata: DeclClassApplied, + cls_metadata: util.DeclClassApplied, ): """Extract mapping information from a @declared_attr in a declarative class. @@ -201,7 +236,7 @@ def _scan_declarative_decorator_stmt( left_hand_explicit_type = UnionType( [ - _extract_python_type_from_typeengine( + infer._extract_python_type_from_typeengine( api, sym.node, [] ), NoneType(), @@ -279,7 +314,7 @@ def _scan_declarative_assignment_stmt( cls: ClassDef, api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, - cls_metadata: DeclClassApplied, + cls_metadata: util.DeclClassApplied, ): """Extract mapping information from an assignment statement in a declarative class. @@ -317,7 +352,7 @@ def _scan_declarative_assignment_stmt( else: for item in stmt.rvalue.items: if isinstance(item, (NameExpr, StrExpr)): - _apply_mypy_mapped_attr(cls, api, item, cls_metadata) + apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata) left_hand_mapped_type: Type = None @@ -378,24 +413,26 @@ def _scan_declarative_assignment_stmt( if type_id is None: return elif type_id is names.COLUMN: - python_type_for_type = _infer_type_from_decl_column( + python_type_for_type = infer._infer_type_from_decl_column( api, stmt, node, left_hand_explicit_type, stmt.rvalue ) elif type_id is names.RELATIONSHIP: - python_type_for_type = _infer_type_from_relationship( + python_type_for_type = infer._infer_type_from_relationship( api, stmt, node, left_hand_explicit_type ) elif type_id is names.COLUMN_PROPERTY: - python_type_for_type = _infer_type_from_decl_column_property( + python_type_for_type = infer._infer_type_from_decl_column_property( api, stmt, node, left_hand_explicit_type ) elif type_id is names.SYNONYM_PROPERTY: - python_type_for_type = _infer_type_from_left_hand_type_only( + python_type_for_type = infer._infer_type_from_left_hand_type_only( api, node, left_hand_explicit_type ) elif type_id is names.COMPOSITE_PROPERTY: - python_type_for_type = _infer_type_from_decl_composite_property( - api, stmt, node, left_hand_explicit_type + python_type_for_type = ( + infer._infer_type_from_decl_composite_property( + api, stmt, node, left_hand_explicit_type + ) ) else: return @@ -407,7 +444,7 @@ def _scan_declarative_assignment_stmt( assert python_type_for_type is not None - _apply_type_to_mapped_statement( + apply._apply_type_to_mapped_statement( api, stmt, lvalue, @@ -416,486 +453,10 @@ def _scan_declarative_assignment_stmt( ) -def _apply_mypy_mapped_attr( - cls: ClassDef, - api: SemanticAnalyzerPluginInterface, - item: Union[NameExpr, StrExpr], - cls_metadata: DeclClassApplied, -): - if isinstance(item, NameExpr): - name = item.name - elif isinstance(item, StrExpr): - name = item.value - else: - return - - for stmt in cls.defs.body: - if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name: - break - else: - util.fail(api, "Can't find mapped attribute {}".format(name), cls) - return - - if stmt.type is None: - util.fail( - api, - "Statement linked from _mypy_mapped_attrs has no " - "typing information", - stmt, - ) - return - - left_hand_explicit_type = stmt.type - - cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) - - _apply_type_to_mapped_statement( - api, stmt, stmt.lvalues[0], left_hand_explicit_type, None - ) - - -def _infer_type_from_relationship( - api: SemanticAnalyzerPluginInterface, - stmt: AssignmentStmt, - node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: - """Infer the type of mapping from a relationship. - - E.g.:: - - @reg.mapped - class MyClass: - # ... - - addresses = relationship(Address, uselist=True) - - order: Mapped["Order"] = relationship("Order") - - Will resolve in mypy as:: - - @reg.mapped - class MyClass: - # ... - - addresses: Mapped[List[Address]] - - order: Mapped["Order"] - - """ - - assert isinstance(stmt.rvalue, CallExpr) - target_cls_arg = stmt.rvalue.args[0] - python_type_for_type = None - - if isinstance(target_cls_arg, NameExpr) and isinstance( - target_cls_arg.node, TypeInfo - ): - # type - related_object_type = target_cls_arg.node - python_type_for_type = Instance(related_object_type, []) - - # other cases not covered - an error message directs the user - # to set an explicit type annotation - # - # node.type == str, it's a string - # if isinstance(target_cls_arg, NameExpr) and isinstance( - # target_cls_arg.node, Var - # ) - # points to a type - # isinstance(target_cls_arg, NameExpr) and isinstance( - # target_cls_arg.node, TypeAlias - # ) - # string expression - # isinstance(target_cls_arg, StrExpr) - - uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist") - collection_cls_arg = util._get_callexpr_kwarg( - stmt.rvalue, "collection_class" - ) - - # this can be used to determine Optional for a many-to-one - # in the same way nullable=False could be used, if we start supporting - # that. - # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin") - - if ( - uselist_arg is not None - and uselist_arg.fullname == "builtins.True" - and collection_cls_arg is None - ): - if python_type_for_type is not None: - python_type_for_type = Instance( - api.lookup_fully_qualified("builtins.list").node, - [python_type_for_type], - ) - elif ( - uselist_arg is None or uselist_arg.fullname == "builtins.True" - ) and collection_cls_arg is not None: - if isinstance(collection_cls_arg.node, TypeInfo): - if python_type_for_type is not None: - python_type_for_type = Instance( - collection_cls_arg.node, [python_type_for_type] - ) - else: - util.fail( - api, - "Expected Python collection type for " - "collection_class parameter", - stmt.rvalue, - ) - python_type_for_type = None - elif uselist_arg is not None and uselist_arg.fullname == "builtins.False": - if collection_cls_arg is not None: - util.fail( - api, - "Sending uselist=False and collection_class at the same time " - "does not make sense", - stmt.rvalue, - ) - if python_type_for_type is not None: - python_type_for_type = UnionType( - [python_type_for_type, NoneType()] - ) - - else: - if left_hand_explicit_type is None: - msg = ( - "Can't infer scalar or collection for ORM mapped expression " - "assigned to attribute '{}' if both 'uselist' and " - "'collection_class' arguments are absent from the " - "relationship(); please specify a " - "type annotation on the left hand side." - ) - util.fail(api, msg.format(node.name), node) - - if python_type_for_type is None: - return _infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) - elif left_hand_explicit_type is not None: - return _infer_type_from_left_and_inferred_right( - api, node, left_hand_explicit_type, python_type_for_type - ) - else: - return python_type_for_type - - -def _infer_type_from_decl_composite_property( - api: SemanticAnalyzerPluginInterface, - stmt: AssignmentStmt, - node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: - """Infer the type of mapping from a CompositeProperty.""" - - assert isinstance(stmt.rvalue, CallExpr) - target_cls_arg = stmt.rvalue.args[0] - python_type_for_type = None - - if isinstance(target_cls_arg, NameExpr) and isinstance( - target_cls_arg.node, TypeInfo - ): - related_object_type = target_cls_arg.node - python_type_for_type = Instance(related_object_type, []) - else: - python_type_for_type = None - - if python_type_for_type is None: - return _infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) - elif left_hand_explicit_type is not None: - return _infer_type_from_left_and_inferred_right( - api, node, left_hand_explicit_type, python_type_for_type - ) - else: - return python_type_for_type - - -def _infer_type_from_decl_column_property( - api: SemanticAnalyzerPluginInterface, - stmt: AssignmentStmt, - node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Union[Instance, UnionType, None]: - """Infer the type of mapping from a ColumnProperty. - - This includes mappings against ``column_property()`` as well as the - ``deferred()`` function. - - """ - assert isinstance(stmt.rvalue, CallExpr) - first_prop_arg = stmt.rvalue.args[0] - - if isinstance(first_prop_arg, CallExpr): - type_id = names._type_id_for_callee(first_prop_arg.callee) - else: - type_id = None - - print(stmt.lvalues[0].name) - - # look for column_property() / deferred() etc with Column as first - # argument - if type_id is names.COLUMN: - return _infer_type_from_decl_column( - api, stmt, node, left_hand_explicit_type, first_prop_arg - ) - else: - return _infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) - - -def _infer_type_from_decl_column( - api: SemanticAnalyzerPluginInterface, - stmt: AssignmentStmt, - node: Var, - left_hand_explicit_type: Optional[types.Type], - right_hand_expression: CallExpr, -) -> Union[Instance, UnionType, None]: - """Infer the type of mapping from a Column. - - E.g.:: - - @reg.mapped - class MyClass: - # ... - - a = Column(Integer) - - b = Column("b", String) - - c: Mapped[int] = Column(Integer) - - d: bool = Column(Boolean) - - Will resolve in MyPy as:: - - @reg.mapped - class MyClass: - # ... - - a : Mapped[int] - - b : Mapped[str] - - c: Mapped[int] - - d: Mapped[bool] - - """ - assert isinstance(node, Var) - - callee = None - - for column_arg in right_hand_expression.args[0:2]: - if isinstance(column_arg, nodes.CallExpr): - # x = Column(String(50)) - callee = column_arg.callee - type_args = column_arg.args - break - elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)): - if isinstance(column_arg.node, TypeInfo): - # x = Column(String) - callee = column_arg - type_args = () - break - else: - # x = Column(some_name, String), go to next argument - continue - elif isinstance(column_arg, (StrExpr,)): - # x = Column("name", String), go to next argument - continue - else: - assert False - - if callee is None: - return None - - if isinstance(callee.node, TypeInfo) and names._mro_has_id( - callee.node.mro, names.TYPEENGINE - ): - python_type_for_type = _extract_python_type_from_typeengine( - api, callee.node, type_args - ) - - if left_hand_explicit_type is not None: - - return _infer_type_from_left_and_inferred_right( - api, node, left_hand_explicit_type, python_type_for_type - ) - - else: - python_type_for_type = UnionType( - [python_type_for_type, NoneType()] - ) - return python_type_for_type - else: - # it's not TypeEngine, it's typically implicitly typed - # like ForeignKey. we can't infer from the right side. - return _infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) - - -def _infer_type_from_left_and_inferred_right( - api: SemanticAnalyzerPluginInterface, - node: Var, - left_hand_explicit_type: Optional[types.Type], - python_type_for_type: Union[Instance, UnionType], -) -> Optional[Union[Instance, UnionType]]: - """Validate type when a left hand annotation is present and we also - could infer the right hand side:: - - attrname: SomeType = Column(SomeDBType) - - """ - if not is_subtype(left_hand_explicit_type, python_type_for_type): - descriptor = api.lookup("__sa_Mapped", node) - - effective_type = Instance(descriptor.node, [python_type_for_type]) - - msg = ( - "Left hand assignment '{}: {}' not compatible " - "with ORM mapped expression of type {}" - ) - util.fail( - api, - msg.format( - node.name, - format_type(left_hand_explicit_type), - format_type(effective_type), - ), - node, - ) - - return left_hand_explicit_type - - -def _infer_type_from_left_hand_type_only( - api: SemanticAnalyzerPluginInterface, - node: Var, - left_hand_explicit_type: Optional[types.Type], -) -> Optional[Union[Instance, UnionType]]: - """Determine the type based on explicit annotation only. - - if no annotation were present, note that we need one there to know - the type. - - """ - if left_hand_explicit_type is None: - msg = ( - "Can't infer type from ORM mapped expression " - "assigned to attribute '{}'; please specify a " - "Python type or " - "Mapped[] on the left hand side." - ) - util.fail(api, msg.format(node.name), node) - - descriptor = api.lookup("__sa_Mapped", node) - return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)]) - - else: - # use type from the left hand side - return left_hand_explicit_type - - -def _re_apply_declarative_assignments( - cls: ClassDef, - api: SemanticAnalyzerPluginInterface, - cls_metadata: DeclClassApplied, -): - """For multiple class passes, re-apply our left-hand side types as mypy - seems to reset them in place. - - """ - mapped_attr_lookup = { - name: typ for name, typ in cls_metadata.mapped_attr_names - } - - descriptor = api.lookup("__sa_Mapped", cls) - for stmt in cls.defs.body: - # for a re-apply, all of our statements are AssignmentStmt; - # @declared_attr calls will have been converted and this - # currently seems to be preserved by mypy (but who knows if this - # will change). - if ( - isinstance(stmt, AssignmentStmt) - and stmt.lvalues[0].name in mapped_attr_lookup - ): - typ = mapped_attr_lookup[stmt.lvalues[0].name] - left_node = stmt.lvalues[0].node - - inst = Instance(descriptor.node, [typ]) - left_node.type = inst - - -def _apply_type_to_mapped_statement( - api: SemanticAnalyzerPluginInterface, - stmt: AssignmentStmt, - lvalue: NameExpr, - left_hand_explicit_type: Optional[Union[Instance, UnionType]], - python_type_for_type: Union[Instance, UnionType], -) -> None: - """Apply the Mapped[] annotation and right hand object to a - declarative assignment statement. - - This converts a Python declarative class statement such as:: - - class User(Base): - # ... - - attrname = Column(Integer) - - To one that describes the final Python behavior to Mypy:: - - class User(Base): - # ... - - attrname : Mapped[Optional[int]] = - - """ - descriptor = api.lookup("__sa_Mapped", stmt) - left_node = lvalue.node - - inst = Instance(descriptor.node, [python_type_for_type]) - - if left_hand_explicit_type is not None: - left_node.type = Instance(descriptor.node, [left_hand_explicit_type]) - else: - lvalue.is_inferred_def = False - left_node.type = inst - - # so to have it skip the right side totally, we can do this: - # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) - - # however, if we instead manufacture a new node that uses the old - # one, then we can still get type checking for the call itself, - # e.g. the Column, relationship() call, etc. - - # rewrite the node as: - # : Mapped[] = - # _sa_Mapped._empty_constructor() - # the original right-hand side is maintained so it gets type checked - # internally - api.add_symbol_table_node("_sa_Mapped", descriptor) - column_descriptor = nodes.NameExpr("_sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.Mapped" - mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") - orig_call_expr = stmt.rvalue - stmt.rvalue = CallExpr( - mm, - [orig_call_expr], - [nodes.ARG_POS], - ["arg1"], - ) - - def _scan_for_mapped_bases( cls: ClassDef, api: SemanticAnalyzerPluginInterface, - cls_metadata: DeclClassApplied, + cls_metadata: util.DeclClassApplied, ) -> None: """Given a class, iterate through its superclass hierarchy to find all other classes that are considered as ORM-significant. @@ -905,99 +466,25 @@ def _scan_for_mapped_bases( """ - baseclasses = list(cls.info.bases) + info = util._info_for_cls(cls, api) + + baseclasses = list(info.bases) + while baseclasses: base: Instance = baseclasses.pop(0) + if base.type.fullname.startswith("builtins"): + continue + # scan each base for mapped attributes. if they are not already - # scanned, that means they are unmapped mixins + # scanned (but have all their type info), that means they are unmapped + # mixins base_decl_class_applied = ( _scan_declarative_assignments_and_apply_types( base.type.defn, api, is_mixin_scan=True ) ) - if base_decl_class_applied is not None: + + if base_decl_class_applied not in (None, False): cls_metadata.mapped_mro.append(base) baseclasses.extend(base.type.bases) - - -def _add_additional_orm_attributes( - cls: ClassDef, - api: SemanticAnalyzerPluginInterface, - cls_metadata: DeclClassApplied, -) -> None: - """Apply __init__, __table__ and other attributes to the mapped class.""" - if "__init__" not in cls.info.names and cls_metadata.is_mapped: - mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names} - - for mapped_base in cls_metadata.mapped_mro: - base_cls_metadata = DeclClassApplied.deserialize( - mapped_base.type.metadata["_sa_decl_class_applied"], api - ) - for n, t in base_cls_metadata.mapped_attr_names: - mapped_attr_names.setdefault(n, t) - - arguments = [] - for name, typ in mapped_attr_names.items(): - if typ is None: - typ = AnyType(TypeOfAny.special_form) - arguments.append( - Argument( - variable=Var(name, typ), - type_annotation=typ, - initializer=TempNode(typ), - kind=ARG_NAMED_OPT, - ) - ) - add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) - - if "__table__" not in cls.info.names and cls_metadata.has_table: - _apply_placeholder_attr_to_class( - api, cls, "sqlalchemy.sql.schema.Table", "__table__" - ) - if cls_metadata.is_mapped: - _apply_placeholder_attr_to_class( - api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" - ) - - -def _apply_placeholder_attr_to_class( - api: SemanticAnalyzerPluginInterface, - cls: ClassDef, - qualified_name: str, - attrname: str, -): - sym = api.lookup_fully_qualified_or_none(qualified_name) - if sym: - assert isinstance(sym.node, TypeInfo) - type_ = Instance(sym.node, []) - else: - type_ = AnyType(TypeOfAny.special_form) - var = Var(attrname) - var.info = cls.info - var.type = type_ - cls.info.names[attrname] = SymbolTableNode(MDEF, var) - - -def _extract_python_type_from_typeengine( - api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args -) -> Instance: - if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: - first_arg = type_args[0] - if isinstance(first_arg, NameExpr) and isinstance( - first_arg.node, TypeInfo - ): - for base_ in first_arg.node.mro: - if base_.fullname == "enum.Enum": - return Instance(first_arg.node, []) - # TODO: support other pep-435 types here - else: - n = api.lookup_fully_qualified("builtins.str") - return Instance(n.node, []) - - for mr in node.mro: - if mr.bases: - for base_ in mr.bases: - if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine": - return base_.args[-1] - assert False, "could not extract Python type from node: %s" % node diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py new file mode 100644 index 0000000000..1d77e67d2e --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -0,0 +1,398 @@ +# ext/mypy/infer.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from typing import Optional +from typing import Union + +from mypy import nodes +from mypy import types +from mypy.messages import format_type +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import NameExpr +from mypy.nodes import StrExpr +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.subtypes import is_subtype +from mypy.types import AnyType +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import TypeOfAny +from mypy.types import UnionType + +from . import names +from . import util + + +def _infer_type_from_relationship( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[types.Type], +) -> Union[Instance, UnionType, None]: + """Infer the type of mapping from a relationship. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + addresses = relationship(Address, uselist=True) + + order: Mapped["Order"] = relationship("Order") + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + addresses: Mapped[List[Address]] + + order: Mapped["Order"] + + """ + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + # type + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + + # other cases not covered - an error message directs the user + # to set an explicit type annotation + # + # node.type == str, it's a string + # if isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, Var + # ) + # points to a type + # isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, TypeAlias + # ) + # string expression + # isinstance(target_cls_arg, StrExpr) + + uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist") + collection_cls_arg = util._get_callexpr_kwarg( + stmt.rvalue, "collection_class" + ) + + # this can be used to determine Optional for a many-to-one + # in the same way nullable=False could be used, if we start supporting + # that. + # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin") + + if ( + uselist_arg is not None + and uselist_arg.fullname == "builtins.True" + and collection_cls_arg is None + ): + if python_type_for_type is not None: + python_type_for_type = Instance( + api.lookup_fully_qualified("builtins.list").node, + [python_type_for_type], + ) + elif ( + uselist_arg is None or uselist_arg.fullname == "builtins.True" + ) and collection_cls_arg is not None: + if isinstance(collection_cls_arg.node, TypeInfo): + if python_type_for_type is not None: + python_type_for_type = Instance( + collection_cls_arg.node, [python_type_for_type] + ) + else: + util.fail( + api, + "Expected Python collection type for " + "collection_class parameter", + stmt.rvalue, + ) + python_type_for_type = None + elif uselist_arg is not None and uselist_arg.fullname == "builtins.False": + if collection_cls_arg is not None: + util.fail( + api, + "Sending uselist=False and collection_class at the same time " + "does not make sense", + stmt.rvalue, + ) + if python_type_for_type is not None: + python_type_for_type = UnionType( + [python_type_for_type, NoneType()] + ) + + else: + if left_hand_explicit_type is None: + msg = ( + "Can't infer scalar or collection for ORM mapped expression " + "assigned to attribute '{}' if both 'uselist' and " + "'collection_class' arguments are absent from the " + "relationship(); please specify a " + "type annotation on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + if python_type_for_type is None: + return _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_composite_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[types.Type], +) -> Union[Instance, UnionType, None]: + """Infer the type of mapping from a CompositeProperty.""" + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + else: + python_type_for_type = None + + if python_type_for_type is None: + return _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_column_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[types.Type], +) -> Union[Instance, UnionType, None]: + """Infer the type of mapping from a ColumnProperty. + + This includes mappings against ``column_property()`` as well as the + ``deferred()`` function. + + """ + assert isinstance(stmt.rvalue, CallExpr) + first_prop_arg = stmt.rvalue.args[0] + + if isinstance(first_prop_arg, CallExpr): + type_id = names._type_id_for_callee(first_prop_arg.callee) + else: + type_id = None + + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type, first_prop_arg + ) + else: + return _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[types.Type], + right_hand_expression: CallExpr, +) -> Union[Instance, UnionType, None]: + """Infer the type of mapping from a Column. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + a = Column(Integer) + + b = Column("b", String) + + c: Mapped[int] = Column(Integer) + + d: bool = Column(Boolean) + + Will resolve in MyPy as:: + + @reg.mapped + class MyClass: + # ... + + a : Mapped[int] + + b : Mapped[str] + + c: Mapped[int] + + d: Mapped[bool] + + """ + assert isinstance(node, Var) + + callee = None + + for column_arg in right_hand_expression.args[0:2]: + if isinstance(column_arg, nodes.CallExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args = column_arg.args + break + elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)): + if isinstance(column_arg.node, TypeInfo): + # x = Column(String) + callee = column_arg + type_args = () + break + else: + # x = Column(some_name, String), go to next argument + continue + elif isinstance(column_arg, (StrExpr,)): + # x = Column("name", String), go to next argument + continue + else: + assert False + + if callee is None: + return None + + if isinstance(callee.node, TypeInfo) and names._mro_has_id( + callee.node.mro, names.TYPEENGINE + ): + python_type_for_type = _extract_python_type_from_typeengine( + api, callee.node, type_args + ) + + if left_hand_explicit_type is not None: + + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + + else: + python_type_for_type = UnionType( + [python_type_for_type, NoneType()] + ) + return python_type_for_type + else: + # it's not TypeEngine, it's typically implicitly typed + # like ForeignKey. we can't infer from the right side. + return _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Optional[types.Type], + python_type_for_type: Union[Instance, UnionType], +) -> Optional[Union[Instance, UnionType]]: + """Validate type when a left hand annotation is present and we also + could infer the right hand side:: + + attrname: SomeType = Column(SomeDBType) + + """ + if not is_subtype(left_hand_explicit_type, python_type_for_type): + descriptor = api.lookup("__sa_Mapped", node) + + effective_type = Instance(descriptor.node, [python_type_for_type]) + + msg = ( + "Left hand assignment '{}: {}' not compatible " + "with ORM mapped expression of type {}" + ) + util.fail( + api, + msg.format( + node.name, + format_type(left_hand_explicit_type), + format_type(effective_type), + ), + node, + ) + + return left_hand_explicit_type + + +def _infer_type_from_left_hand_type_only( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Optional[types.Type], +) -> Optional[Union[Instance, UnionType]]: + """Determine the type based on explicit annotation only. + + if no annotation were present, note that we need one there to know + the type. + + """ + if left_hand_explicit_type is None: + msg = ( + "Can't infer type from ORM mapped expression " + "assigned to attribute '{}'; please specify a " + "Python type or " + "Mapped[] on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + descriptor = api.lookup("__sa_Mapped", node) + return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)]) + + else: + # use type from the left hand side + return left_hand_explicit_type + + +def _extract_python_type_from_typeengine( + api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args +) -> Instance: + if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: + first_arg = type_args[0] + if isinstance(first_arg, NameExpr) and isinstance( + first_arg.node, TypeInfo + ): + for base_ in first_arg.node.mro: + if base_.fullname == "enum.Enum": + return Instance(first_arg.node, []) + # TODO: support other pep-435 types here + else: + n = api.lookup_fully_qualified("builtins.str") + return Instance(n.node, []) + + for mr in node.mro: + if mr.bases: + for base_ in mr.bases: + if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine": + return base_.args[-1] + assert False, "could not extract Python type from node: %s" % node diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index d1fd77415a..11208f3c71 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -36,6 +36,7 @@ DECLARED_ATTR = util.symbol("DECLARED_ATTR") MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY") AS_DECLARATIVE = util.symbol("AS_DECLARATIVE") AS_DECLARATIVE_BASE = util.symbol("AS_DECLARATIVE_BASE") +DECLARATIVE_MIXIN = util.symbol("DECLARATIVE_MIXIN") _lookup = { "Column": ( @@ -134,6 +135,13 @@ _lookup = { "sqlalchemy.orm.declared_attr", }, ), + "declarative_mixin": ( + DECLARATIVE_MIXIN, + { + "sqlalchemy.orm.decl_api.declarative_mixin", + "sqlalchemy.orm.declarative_mixin", + }, + ), } diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 9ca1cb2daf..a0aa5bf040 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -55,6 +55,7 @@ class CustomPlugin(Plugin): # subclasses. but then you can just check it here from the "base" # and get the same effect. sym = self.lookup_fully_qualified(fullname) + if ( sym and isinstance(sym.node, TypeInfo) @@ -70,17 +71,18 @@ class CustomPlugin(Plugin): ) -> Optional[Callable[[ClassDefContext], None]]: sym = self.lookup_fully_qualified(fullname) - if ( - sym is not None - and names._type_id_for_named_node(sym.node) - is names.MAPPED_DECORATOR - ): - return _cls_decorator_hook - elif sym is not None and names._type_id_for_named_node(sym.node) in ( - names.AS_DECLARATIVE, - names.AS_DECLARATIVE_BASE, - ): - return _base_cls_decorator_hook + + if sym is not None: + type_id = names._type_id_for_named_node(sym.node) + if type_id is names.MAPPED_DECORATOR: + return _cls_decorator_hook + elif type_id in ( + names.AS_DECLARATIVE, + names.AS_DECLARATIVE_BASE, + ): + return _base_cls_decorator_hook + elif type_id is names.DECLARATIVE_MIXIN: + return _declarative_mixin_hook return None @@ -192,6 +194,13 @@ def _base_cls_hook(ctx: ClassDefContext) -> None: decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) +def _declarative_mixin_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + decl_class._scan_declarative_assignments_and_apply_types( + ctx.cls, ctx.api, is_mixin_scan=True + ) + + def _cls_decorator_hook(ctx: ClassDefContext) -> None: _add_globals(ctx) assert isinstance(ctx.reason, nodes.MemberExpr) diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 7079f3cd78..becce3ebec 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,18 +1,67 @@ from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from mypy.nodes import CallExpr +from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context from mypy.nodes import IfStmt +from mypy.nodes import JsonDict from mypy.nodes import NameExpr from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import deserialize_and_fixup_type from mypy.types import Instance from mypy.types import NoneType -from mypy.types import Type from mypy.types import UnboundType from mypy.types import UnionType +class DeclClassApplied: + def __init__( + self, + is_mapped: bool, + has_table: bool, + mapped_attr_names: Sequence[Tuple[str, Type]], + mapped_mro: Sequence[Type], + ): + self.is_mapped = is_mapped + self.has_table = has_table + self.mapped_attr_names = mapped_attr_names + self.mapped_mro = mapped_mro + + def serialize(self) -> JsonDict: + return { + "is_mapped": self.is_mapped, + "has_table": self.has_table, + "mapped_attr_names": [ + (name, type_.serialize()) + for name, type_ in self.mapped_attr_names + ], + "mapped_mro": [type_.serialize() for type_ in self.mapped_mro], + } + + @classmethod + def deserialize( + cls, data: JsonDict, api: SemanticAnalyzerPluginInterface + ) -> "DeclClassApplied": + + return DeclClassApplied( + is_mapped=data["is_mapped"], + has_table=data["has_table"], + mapped_attr_names=[ + (name, deserialize_and_fixup_type(type_, api)) + for name, type_ in data["mapped_attr_names"] + ], + mapped_mro=[ + deserialize_and_fixup_type(type_, api) + for type_ in data["mapped_mro"] + ], + ) + + def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context): msg = "[SQLAlchemy Mypy plugin] %s" % msg return api.fail(msg, ctx) @@ -94,3 +143,14 @@ def _unbound_to_instance( ) else: return typ + + +def _info_for_cls(cls, api): + if cls.info is CLASSDEF_NO_INFO: + sym = api.lookup(cls.name, cls) + if sym.node and isinstance(sym.node, TypeInfo): + info = sym.node + else: + info = cls.info + + return info diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 025d826e34..66c3e7e33e 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -23,6 +23,7 @@ from .attributes import QueryableAttribute from .context import QueryContext from .decl_api import as_declarative from .decl_api import declarative_base +from .decl_api import declarative_mixin from .decl_api import DeclarativeMeta from .decl_api import declared_attr from .decl_api import has_inherited_table diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index ef53e2d399..d9c464815b 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -321,6 +321,48 @@ class _stateful_declared_attr(declared_attr): return declared_attr(fn, **self.kw) +def declarative_mixin(cls): + """Mark a class as providing the feature of "declarative mixin". + + E.g.:: + + from sqlalchemy.orm import declared_attr + from sqlalchemy.orm import declarative_mixin + + @declarative_mixin + class MyMixin: + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {'mysql_engine': 'InnoDB'} + __mapper_args__= {'always_refresh': True} + + id = Column(Integer, primary_key=True) + + class MyModel(MyMixin, Base): + name = Column(String(1000)) + + The :func:`_orm.declarative_mixin` decorator currently does not modify + the given class in any way; it's current purpose is strictly to assist + the :ref:`Mypy plugin ` in being able to identify + SQLAlchemy declarative mixin classes when no other context is present. + + .. versionadded:: 1.4.6 + + .. seealso:: + + :ref:`orm_mixins_toplevel` + + :ref:`mypy_declarative_mixins` - in the + :ref:`Mypy plugin documentation ` + + """ # noqa: E501 + + return cls + + def declarative_base( bind=None, metadata=None, diff --git a/test/ext/mypy/files/mixin_three.py b/test/ext/mypy/files/mixin_three.py new file mode 100644 index 0000000000..cb8e30df81 --- /dev/null +++ b/test/ext/mypy/files/mixin_three.py @@ -0,0 +1,33 @@ +from typing import Callable + +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import deferred +from sqlalchemy.orm import Mapped +from sqlalchemy.orm.decl_api import declarative_mixin +from sqlalchemy.orm.decl_api import declared_attr +from sqlalchemy.orm.interfaces import MapperProperty + + +def some_other_decorator(fn: Callable[..., None]) -> Callable[..., None]: + return fn + + +@declarative_mixin +class HasAMixin: + x: Mapped[int] = Column(Integer) + + y = Column(String) + + @declared_attr + def data(cls) -> Column[String]: + return Column(String) + + @declared_attr + def data2(cls) -> MapperProperty[str]: + return deferred(Column(String)) + + @some_other_decorator + def q(cls) -> None: + return None diff --git a/test/ext/mypy/incremental/ticket_6147/__init__.py b/test/ext/mypy/incremental/ticket_6147/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/ext/mypy/incremental/ticket_6147/base.py b/test/ext/mypy/incremental/ticket_6147/base.py new file mode 100644 index 0000000000..59be70308c --- /dev/null +++ b/test/ext/mypy/incremental/ticket_6147/base.py @@ -0,0 +1,3 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/test/ext/mypy/incremental/ticket_6147/one.py b/test/ext/mypy/incremental/ticket_6147/one.py new file mode 100644 index 0000000000..17fb075ac4 --- /dev/null +++ b/test/ext/mypy/incremental/ticket_6147/one.py @@ -0,0 +1,13 @@ +from sqlalchemy import Column +from sqlalchemy import Integer +from .base import Base + + +class One(Base): + __tablename__ = "one" + id = Column(Integer, primary_key=True) + + +o1 = One(id=5) + +One.id.in_([1, 2]) diff --git a/test/ext/mypy/incremental/ticket_6147/patch1.testpatch b/test/ext/mypy/incremental/ticket_6147/patch1.testpatch new file mode 100644 index 0000000000..b1d9bde011 --- /dev/null +++ b/test/ext/mypy/incremental/ticket_6147/patch1.testpatch @@ -0,0 +1,19 @@ +--- a/one.py 2021-04-03 15:32:22.214287290 -0400 ++++ b/one.py 2021-04-03 15:34:56.397398510 -0400 +@@ -1,15 +1,13 @@ + from sqlalchemy import Column + from sqlalchemy import Integer +-from sqlalchemy import String + from .base import Base + + + class One(Base): + __tablename__ = "one" + id = Column(Integer, primary_key=True) +- name = Column(String(50)) + + +-o1 = One(id=5, name="name") ++o1 = One(id=5) + + One.id.in_([1, 2]) diff --git a/test/ext/mypy/incremental/ticket_6147/patch2.testpatch b/test/ext/mypy/incremental/ticket_6147/patch2.testpatch new file mode 100644 index 0000000000..7551659571 --- /dev/null +++ b/test/ext/mypy/incremental/ticket_6147/patch2.testpatch @@ -0,0 +1,38 @@ +--- a/base.py 2021-04-03 16:36:30.201594994 -0400 ++++ b/base.py 2021-04-03 16:38:26.404475025 -0400 +@@ -1,3 +1,15 @@ ++from sqlalchemy import Column ++from sqlalchemy import Integer ++from sqlalchemy import String + from sqlalchemy.orm import declarative_base ++from sqlalchemy.orm import declarative_mixin ++from sqlalchemy.orm import Mapped + + Base = declarative_base() ++ ++ ++@declarative_mixin ++class Mixin: ++ mixed = Column(String) ++ ++ b_int: Mapped[int] = Column(Integer) +--- a/one.py 2021-04-03 16:37:17.906956282 -0400 ++++ b/one.py 2021-04-03 16:38:33.469528528 -0400 +@@ -1,13 +1,15 @@ + from sqlalchemy import Column + from sqlalchemy import Integer ++ + from .base import Base ++from .base import Mixin + + +-class One(Base): ++class One(Mixin, Base): + __tablename__ = "one" + id = Column(Integer, primary_key=True) + + +-o1 = One(id=5) ++o1 = One(id=5, mixed="mixed", b_int=5) + + One.id.in_([1, 2]) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index c8d042db0b..4ab16540d3 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -11,8 +11,17 @@ from sqlalchemy.testing import fixtures class MypyPluginTest(fixtures.TestBase): __requires__ = ("sqlalchemy2_stubs",) + @testing.fixture(scope="function") + def per_func_cachedir(self): + for item in self._cachedir(): + yield item + @testing.fixture(scope="class") def cachedir(self): + for item in self._cachedir(): + yield item + + def _cachedir(self): with tempfile.TemporaryDirectory() as cachedir: with open( os.path.join(cachedir, "sqla_mypy_config.cfg"), "w" @@ -77,9 +86,11 @@ class MypyPluginTest(fixtures.TestBase): *[(dirname,) for dirname in _incremental_dirs()], argnames="dirname" ) @testing.requires.patch_library - def test_incremental(self, mypy_runner, cachedir, dirname): + def test_incremental(self, mypy_runner, per_func_cachedir, dirname): import patch + cachedir = per_func_cachedir + path = os.path.join(os.path.dirname(__file__), "incremental", dirname) dest = os.path.join(cachedir, "mymodel") os.mkdir(dest) @@ -101,7 +112,9 @@ class MypyPluginTest(fixtures.TestBase): if patchfile is not None: print("Applying patchfile %s" % patchfile) patch_obj = patch.fromfile(os.path.join(path, patchfile)) - patch_obj.apply(1, dest) + assert patch_obj.apply(1, dest), ( + "pathfile %s failed" % patchfile + ) print("running mypy against %s/mymodel" % cachedir) result = mypy_runner( "mymodel", diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 05628641a9..664c006303 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import column_property from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declarative_mixin from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import events as orm_events @@ -103,6 +104,39 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(obj.name, "testing") eq_(obj.foo(), "bar1") + def test_declarative_mixin_decorator(self): + + # note we are also making sure an "old style class" in Python 2, + # as we are now illustrating in all the docs for mixins, doesn't cause + # a problem.... + @declarative_mixin + class MyMixin: + + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + + def foo(self): + return "bar" + str(self.id) + + # ...as long as the mapped class itself is "new style", which will + # normally be the case for users using declarative_base + @mapper_registry.mapped + class MyModel(MyMixin, object): + + __tablename__ = "test" + name = Column(String(100), nullable=False, index=True) + + Base.metadata.create_all(testing.db) + session = fixture_session() + session.add(MyModel(name="testing")) + session.flush() + session.expunge_all() + obj = session.query(MyModel).one() + eq_(obj.id, 1) + eq_(obj.name, "testing") + eq_(obj.foo(), "bar1") + def test_unique_column(self): class MyMixin(object): -- 2.47.2