--- /dev/null
+.. change::
+ :tags: bug, mypy
+ :tickets: 6147
+
+    Applied a series of refactorings and fixes to accommodate Mypy
+ "incremental" mode across multiple files, which previously was not taken
+ into account. In this mode the Mypy plugin has to accommodate Python
+ datatypes expressed in other files coming in with less information than
+ they have on a direct run.
+
+ Additionally, a new decorator :func:`_orm.declarative_mixin` is added,
+    which is necessary for the Mypy plugin to be able to definitively identify
+ a Declarative mixin class that is otherwise not used inside a particular
+ Python file.
+
+ .. seealso::
+
+ :ref:`mypy_declarative_mixins`
+
An example of some commonly mixed-in idioms is below::
+ from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import declared_attr
- class MyMixin(object):
+ @declarative_mixin
+ class MyMixin:
@declared_attr
def __tablename__(cls):
from the name of the class itself, as well as ``__table_args__``
and ``__mapper_args__`` defined by the ``MyMixin`` mixin class.
+.. tip::
+
+    The use of the :func:`_orm.declarative_mixin` class decorator marks a
+    particular class as providing SQLAlchemy declarative
+ assignments as a mixin for other classes. This decorator is currently only
+ necessary to provide a hint to the :ref:`Mypy plugin <mypy_toplevel>` that
+ this class should be handled as part of declarative mappings.
+
There's no fixed convention over whether ``MyMixin`` precedes
``Base`` or not. Normal Python method resolution rules apply, and
the above example would work just as well with::
from sqlalchemy.orm import declared_attr
- class Base(object):
+ class Base:
@declared_attr
def __tablename__(cls):
return cls.__name__.lower()
The most basic way to specify a column on a mixin is by simple
declaration::
- class TimestampMixin(object):
+ @declarative_mixin
+ class TimestampMixin:
created_at = Column(DateTime, default=func.now())
class MyModel(TimestampMixin, Base):
from sqlalchemy.orm import declared_attr
- class ReferenceAddressMixin(object):
+ @declarative_mixin
+ class ReferenceAddressMixin:
@declared_attr
def address_id(cls):
return Column(Integer, ForeignKey('address.id'))
by ``polymorphic_on`` and ``version_id_col``; the declarative extension
will resolve them at class construction time::
+ @declarative_mixin
class MyMixin:
@declared_attr
def type_(cls):
relationship so that two classes ``Foo`` and ``Bar`` can both be configured to
reference a common target class via many-to-one::
- class RefTargetMixin(object):
+ @declarative_mixin
+ class RefTargetMixin:
@declared_attr
def target_id(cls):
return Column('target_id', ForeignKey('target.id'))
The canonical example is the primaryjoin condition that depends upon
another mixed-in column::
- class RefTargetMixin(object):
+ @declarative_mixin
+ class RefTargetMixin:
@declared_attr
def target_id(cls):
return Column('target_id', ForeignKey('target.id'))
The condition above is resolved using a lambda::
- class RefTargetMixin(object):
+ @declarative_mixin
+ class RefTargetMixin:
@declared_attr
def target_id(cls):
return Column('target_id', ForeignKey('target.id'))
or alternatively, the string form (which ultimately generates a lambda)::
- class RefTargetMixin(object):
+ @declarative_mixin
+ class RefTargetMixin:
@declared_attr
def target_id(cls):
return Column('target_id', ForeignKey('target.id'))
used with declarative mixins, have the :class:`_orm.declared_attr`
requirement so that no reliance on copying is needed::
- class SomethingMixin(object):
+ @declarative_mixin
+ class SomethingMixin:
@declared_attr
def dprop(cls):
to other columns from the mixin. These are copied ahead of time before
the :class:`_orm.declared_attr` is invoked::
- class SomethingMixin(object):
+ @declarative_mixin
+ class SomethingMixin:
x = Column(Integer)
y = Column(Integer)
string values to an implementing class::
from sqlalchemy import Column, Integer, ForeignKey, String
- from sqlalchemy.orm import relationship
from sqlalchemy.ext.associationproxy import association_proxy
- from sqlalchemy.orm import declarative_base, declared_attr
+ from sqlalchemy.orm import declarative_base
+ from sqlalchemy.orm import declarative_mixin
+ from sqlalchemy.orm import declared_attr
+ from sqlalchemy.orm import relationship
Base = declarative_base()
- class HasStringCollection(object):
+ @declarative_mixin
+ class HasStringCollection:
@declared_attr
def _strings(cls):
class StringAttribute(Base):
For example, to create a mixin that gives every class a simple table
name based on class name::
+ from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import declared_attr
+ @declarative_mixin
class Tablename:
@declared_attr
def __tablename__(cls):
the effect of those subclasses being mapped with single table inheritance
against the parent::
+ from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import has_inherited_table
- class Tablename(object):
+ @declarative_mixin
+ class Tablename:
@declared_attr
def __tablename__(cls):
if has_inherited_table(cls):
called ``id``; the mapping will fail on ``Engineer``, which is not given
a primary key::
- class HasId(object):
+ @declarative_mixin
+ class HasId:
@declared_attr
def id(cls):
return Column('id', Integer, primary_key=True)
function should be invoked **for each class in the hierarchy**, in *almost*
(see warning below) the same way as it does for ``__tablename__``::
- class HasIdMixin(object):
+ @declarative_mixin
+ class HasIdMixin:
@declared_attr.cascading
def id(cls):
if has_inherited_table(cls):
here to create user-defined collation routines that pull
from multiple collections::
+ from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import declared_attr
- class MySQLSettings(object):
+ @declarative_mixin
+ class MySQLSettings:
__table_args__ = {'mysql_engine':'InnoDB'}
- class MyOtherMixin(object):
+ @declarative_mixin
+ class MyOtherMixin:
__table_args__ = {'info':'foo'}
class MyModel(MySQLSettings, MyOtherMixin, Base):
tables derived from a mixin, use the "inline" form of :class:`.Index` and
establish it as part of ``__table_args__``::
- class MyMixin(object):
+ @declarative_mixin
+ class MyMixin:
a = Column(Integer)
b = Column(Integer)
user: Mapped[User] = relationship(User, back_populates="addresses")
-Using @declared_attr
-^^^^^^^^^^^^^^^^^^^^
-
-The :class:`_orm.declared_attr` class allows Declarative mapped attributes
-to be declared in class level functions, and is particularly useful when
-using `declarative mixins <orm_mixins_toplevel>`_. For these functions,
-the return type of the function should be annotated using either the
-``Mapped[]`` construct or by indicating the exact kind of object returned
-by the function::
-
- from sqlalchemy.orm.decl_api import declared_attr
-
+.. _mypy_declarative_mixins:
+
+Using @declared_attr and Declarative Mixins
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :class:`_orm.declared_attr` class allows Declarative mapped attributes to
+be declared in class level functions, and is particularly useful when using
+:ref:`declarative mixins <orm_mixins_toplevel>`. For these functions, the return
+type of the function should be annotated using either the ``Mapped[]``
+construct or by indicating the exact kind of object returned by the function.
+Additionally, "mixin" classes that are not otherwise mapped (i.e. don't extend
+from a :func:`_orm.declarative_base` class nor are they mapped with a method
+such as :meth:`_orm.registry.mapped`) should be decorated with the
+:func:`_orm.declarative_mixin` decorator, which provides a hint to the Mypy
+plugin that a particular class intends to serve as a declarative mixin::
+
+ from sqlalchemy.orm import declared_attr
+ from sqlalchemy.orm import declarative_mixin
+
+ @declarative_mixin
class HasUpdatedAt:
@declared_attr
def updated_at(cls) -> Column[DateTime]: # uses Column
return Column(DateTime)
+ @declarative_mixin
class HasCompany:
@declared_attr
.. autofunction:: declarative_base
+.. autofunction:: declarative_mixin
+
.. autofunction:: as_declarative
.. autoclass:: declared_attr
--- /dev/null
+# ext/mypy/apply.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Union
+
+from mypy import nodes
+from mypy.nodes import ARG_NAMED_OPT
+from mypy.nodes import Argument
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import MDEF
+from mypy.nodes import NameExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import add_method_to_class
+from mypy.types import AnyType
+from mypy.types import Instance
+from mypy.types import NoneTyp
+from mypy.types import TypeOfAny
+from mypy.types import UnionType
+
+from . import util
+
+
+def _apply_mypy_mapped_attr(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ item: Union[NameExpr, StrExpr],
+ cls_metadata: util.DeclClassApplied,
+):
+ if isinstance(item, NameExpr):
+ name = item.name
+ elif isinstance(item, StrExpr):
+ name = item.value
+ else:
+ return
+
+ for stmt in cls.defs.body:
+ if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
+ break
+ else:
+ util.fail(api, "Can't find mapped attribute {}".format(name), cls)
+ return
+
+ if stmt.type is None:
+ util.fail(
+ api,
+ "Statement linked from _mypy_mapped_attrs has no "
+ "typing information",
+ stmt,
+ )
+ return
+
+ left_hand_explicit_type = stmt.type
+
+ cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+
+ _apply_type_to_mapped_statement(
+ api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
+ )
+
+
+def _re_apply_declarative_assignments(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ cls_metadata: util.DeclClassApplied,
+):
+ """For multiple class passes, re-apply our left-hand side types as mypy
+ seems to reset them in place.
+
+ """
+ mapped_attr_lookup = {
+ name: typ for name, typ in cls_metadata.mapped_attr_names
+ }
+
+ descriptor = api.lookup("__sa_Mapped", cls)
+ for stmt in cls.defs.body:
+ # for a re-apply, all of our statements are AssignmentStmt;
+ # @declared_attr calls will have been converted and this
+ # currently seems to be preserved by mypy (but who knows if this
+ # will change).
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and stmt.lvalues[0].name in mapped_attr_lookup
+ ):
+ typ = mapped_attr_lookup[stmt.lvalues[0].name]
+ left_node = stmt.lvalues[0].node
+
+ inst = Instance(descriptor.node, [typ])
+ left_node.type = inst
+
+
+def _apply_type_to_mapped_statement(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ lvalue: NameExpr,
+ left_hand_explicit_type: Optional[Union[Instance, UnionType]],
+ python_type_for_type: Union[Instance, UnionType],
+) -> None:
+ """Apply the Mapped[<type>] annotation and right hand object to a
+ declarative assignment statement.
+
+ This converts a Python declarative class statement such as::
+
+ class User(Base):
+ # ...
+
+ attrname = Column(Integer)
+
+ To one that describes the final Python behavior to Mypy::
+
+ class User(Base):
+ # ...
+
+ attrname : Mapped[Optional[int]] = <meaningless temp node>
+
+ """
+ descriptor = api.lookup("__sa_Mapped", stmt)
+ left_node = lvalue.node
+
+ inst = Instance(descriptor.node, [python_type_for_type])
+
+ if left_hand_explicit_type is not None:
+ left_node.type = Instance(descriptor.node, [left_hand_explicit_type])
+ else:
+ lvalue.is_inferred_def = False
+ left_node.type = inst
+
+ # so to have it skip the right side totally, we can do this:
+ # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # however, if we instead manufacture a new node that uses the old
+ # one, then we can still get type checking for the call itself,
+ # e.g. the Column, relationship() call, etc.
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
+ # the original right-hand side is maintained so it gets type checked
+ # internally
+ column_descriptor = nodes.NameExpr("__sa_Mapped")
+ column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+ mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
+ orig_call_expr = stmt.rvalue
+ stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
+
+
+def _add_additional_orm_attributes(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ cls_metadata: util.DeclClassApplied,
+) -> None:
+ """Apply __init__, __table__ and other attributes to the mapped class."""
+
+ info = util._info_for_cls(cls, api)
+ if "__init__" not in info.names and cls_metadata.is_mapped:
+ mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
+
+ for mapped_base in cls_metadata.mapped_mro:
+ base_cls_metadata = util.DeclClassApplied.deserialize(
+ mapped_base.type.metadata["_sa_decl_class_applied"], api
+ )
+ for n, t in base_cls_metadata.mapped_attr_names:
+ mapped_attr_names.setdefault(n, t)
+
+ arguments = []
+ for name, typ in mapped_attr_names.items():
+ if typ is None:
+ typ = AnyType(TypeOfAny.special_form)
+ arguments.append(
+ Argument(
+ variable=Var(name, typ),
+ type_annotation=typ,
+ initializer=TempNode(typ),
+ kind=ARG_NAMED_OPT,
+ )
+ )
+ add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
+
+ if "__table__" not in info.names and cls_metadata.has_table:
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.sql.schema.Table", "__table__"
+ )
+ if cls_metadata.is_mapped:
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
+ )
+
+
+def _apply_placeholder_attr_to_class(
+ api: SemanticAnalyzerPluginInterface,
+ cls: ClassDef,
+ qualified_name: str,
+ attrname: str,
+):
+ sym = api.lookup_fully_qualified_or_none(qualified_name)
+ if sym:
+ assert isinstance(sym.node, TypeInfo)
+ type_ = Instance(sym.node, [])
+ else:
+ type_ = AnyType(TypeOfAny.special_form)
+ var = Var(attrname)
+ var.info = cls.info
+ var.type = type_
+ cls.info.names[attrname] = SymbolTableNode(MDEF, var)
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from typing import Optional
-from typing import Sequence
-from typing import Tuple
from typing import Type
-from typing import Union
from mypy import nodes
-from mypy import types
-from mypy.messages import format_type
-from mypy.nodes import ARG_NAMED_OPT
-from mypy.nodes import Argument
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
-from mypy.nodes import JsonDict
from mypy.nodes import ListExpr
-from mypy.nodes import MDEF
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
-from mypy.plugins.common import add_method_to_class
-from mypy.plugins.common import deserialize_and_fixup_type
-from mypy.subtypes import is_subtype
from mypy.types import AnyType
from mypy.types import Instance
-from mypy.types import NoneTyp
from mypy.types import NoneType
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType
+from . import apply
+from . import infer
from . import names
from . import util
-class DeclClassApplied:
- def __init__(
- self,
- is_mapped: bool,
- has_table: bool,
- mapped_attr_names: Sequence[Tuple[str, Type]],
- mapped_mro: Sequence[Type],
- ):
- self.is_mapped = is_mapped
- self.has_table = has_table
- self.mapped_attr_names = mapped_attr_names
- self.mapped_mro = mapped_mro
-
- def serialize(self) -> JsonDict:
- return {
- "is_mapped": self.is_mapped,
- "has_table": self.has_table,
- "mapped_attr_names": [
- (name, type_.serialize())
- for name, type_ in self.mapped_attr_names
- ],
- "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
- }
-
- @classmethod
- def deserialize(
- cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
- ) -> "DeclClassApplied":
-
- return DeclClassApplied(
- is_mapped=data["is_mapped"],
- has_table=data["has_table"],
- mapped_attr_names=[
- (name, deserialize_and_fixup_type(type_, api))
- for name, type_ in data["mapped_attr_names"]
- ],
- mapped_mro=[
- deserialize_and_fixup_type(type_, api)
- for type_ in data["mapped_mro"]
- ],
- )
-
-
def _scan_declarative_assignments_and_apply_types(
cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False
-) -> Optional[DeclClassApplied]:
+) -> Optional[util.DeclClassApplied]:
+
+ info = util._info_for_cls(cls, api)
if cls.fullname.startswith("builtins"):
return None
- elif "_sa_decl_class_applied" in cls.info.metadata:
- cls_metadata = DeclClassApplied.deserialize(
- cls.info.metadata["_sa_decl_class_applied"], api
+ elif "_sa_decl_class_applied" in info.metadata:
+ cls_metadata = util.DeclClassApplied.deserialize(
+ info.metadata["_sa_decl_class_applied"], api
)
# ensure that a class that's mapped is always picked up by
# removing our ability to re-scan. but we have the types
# here, so lets re-apply them.
- _re_apply_declarative_assignments(cls, api, cls_metadata)
+ apply._re_apply_declarative_assignments(cls, api, cls_metadata)
return cls_metadata
- cls_metadata = DeclClassApplied(not is_mixin_scan, False, [], [])
+ cls_metadata = util.DeclClassApplied(not is_mixin_scan, False, [], [])
+
+ if not cls.defs.body:
+ # when we get a mixin class from another file, the body is
+ # empty (!) but the names are in the symbol table. so use that.
- for stmt in util._flatten_typechecking(cls.defs.body):
- if isinstance(stmt, AssignmentStmt):
- _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
- elif isinstance(stmt, Decorator):
- _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
+ for sym_name, sym in info.names.items():
+ _scan_symbol_table_entry(cls, api, sym_name, sym, cls_metadata)
+ else:
+ for stmt in util._flatten_typechecking(cls.defs.body):
+ if isinstance(stmt, AssignmentStmt):
+ _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
+ elif isinstance(stmt, Decorator):
+ _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
_scan_for_mapped_bases(cls, api, cls_metadata)
- _add_additional_orm_attributes(cls, api, cls_metadata)
- cls.info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()
+ if not is_mixin_scan:
+ apply._add_additional_orm_attributes(cls, api, cls_metadata)
+
+ info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()
return cls_metadata
+def _scan_symbol_table_entry(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ name: str,
+ value: SymbolTableNode,
+ cls_metadata: util.DeclClassApplied,
+):
+ """Extract mapping information from a SymbolTableNode that's in the
+ type.names dictionary.
+
+ """
+ if not isinstance(value.type, Instance):
+ return
+
+ left_hand_explicit_type = None
+ type_id = names._type_id_for_named_node(value.type.type)
+ # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
+
+ err = False
+
+ # TODO: this is nearly the same logic as that of
+ # _scan_declarative_decorator_stmt, likely can be merged
+ if type_id in {
+ names.MAPPED,
+ names.RELATIONSHIP,
+ names.COMPOSITE_PROPERTY,
+ names.MAPPER_PROPERTY,
+ names.SYNONYM_PROPERTY,
+ names.COLUMN_PROPERTY,
+ }:
+ if value.type.args:
+ left_hand_explicit_type = value.type.args[0]
+ else:
+ err = True
+ elif type_id is names.COLUMN:
+ if not value.type.args:
+ err = True
+ else:
+ typeengine_arg = value.type.args[0]
+ if isinstance(typeengine_arg, Instance):
+ typeengine_arg = typeengine_arg.type
+
+ if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
+ sym = api.lookup(typeengine_arg.name, typeengine_arg)
+ if sym is not None:
+ if names._mro_has_id(sym.node.mro, names.TYPEENGINE):
+
+ left_hand_explicit_type = UnionType(
+ [
+ infer._extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ value.type,
+ )
+
+ if err:
+ msg = (
+ "Can't infer type from attribute {} on class {}. "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(name, cls.name))
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ if left_hand_explicit_type is not None:
+ cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+
+
def _scan_declarative_decorator_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
- cls_metadata: DeclClassApplied,
+ cls_metadata: util.DeclClassApplied,
):
"""Extract mapping information from a @declared_attr in a declarative
class.
left_hand_explicit_type = UnionType(
[
- _extract_python_type_from_typeengine(
+ infer._extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
- cls_metadata: DeclClassApplied,
+ cls_metadata: util.DeclClassApplied,
):
"""Extract mapping information from an assignment statement in a
declarative class.
else:
for item in stmt.rvalue.items:
if isinstance(item, (NameExpr, StrExpr)):
- _apply_mypy_mapped_attr(cls, api, item, cls_metadata)
+ apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
left_hand_mapped_type: Type = None
if type_id is None:
return
elif type_id is names.COLUMN:
- python_type_for_type = _infer_type_from_decl_column(
+ python_type_for_type = infer._infer_type_from_decl_column(
api, stmt, node, left_hand_explicit_type, stmt.rvalue
)
elif type_id is names.RELATIONSHIP:
- python_type_for_type = _infer_type_from_relationship(
+ python_type_for_type = infer._infer_type_from_relationship(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.COLUMN_PROPERTY:
- python_type_for_type = _infer_type_from_decl_column_property(
+ python_type_for_type = infer._infer_type_from_decl_column_property(
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.SYNONYM_PROPERTY:
- python_type_for_type = _infer_type_from_left_hand_type_only(
+ python_type_for_type = infer._infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif type_id is names.COMPOSITE_PROPERTY:
- python_type_for_type = _infer_type_from_decl_composite_property(
- api, stmt, node, left_hand_explicit_type
+ python_type_for_type = (
+ infer._infer_type_from_decl_composite_property(
+ api, stmt, node, left_hand_explicit_type
+ )
)
else:
return
assert python_type_for_type is not None
- _apply_type_to_mapped_statement(
+ apply._apply_type_to_mapped_statement(
api,
stmt,
lvalue,
)
-def _apply_mypy_mapped_attr(
- cls: ClassDef,
- api: SemanticAnalyzerPluginInterface,
- item: Union[NameExpr, StrExpr],
- cls_metadata: DeclClassApplied,
-):
- if isinstance(item, NameExpr):
- name = item.name
- elif isinstance(item, StrExpr):
- name = item.value
- else:
- return
-
- for stmt in cls.defs.body:
- if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
- break
- else:
- util.fail(api, "Can't find mapped attribute {}".format(name), cls)
- return
-
- if stmt.type is None:
- util.fail(
- api,
- "Statement linked from _mypy_mapped_attrs has no "
- "typing information",
- stmt,
- )
- return
-
- left_hand_explicit_type = stmt.type
-
- cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
-
- _apply_type_to_mapped_statement(
- api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
- )
-
-
-def _infer_type_from_relationship(
- api: SemanticAnalyzerPluginInterface,
- stmt: AssignmentStmt,
- node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
- """Infer the type of mapping from a relationship.
-
- E.g.::
-
- @reg.mapped
- class MyClass:
- # ...
-
- addresses = relationship(Address, uselist=True)
-
- order: Mapped["Order"] = relationship("Order")
-
- Will resolve in mypy as::
-
- @reg.mapped
- class MyClass:
- # ...
-
- addresses: Mapped[List[Address]]
-
- order: Mapped["Order"]
-
- """
-
- assert isinstance(stmt.rvalue, CallExpr)
- target_cls_arg = stmt.rvalue.args[0]
- python_type_for_type = None
-
- if isinstance(target_cls_arg, NameExpr) and isinstance(
- target_cls_arg.node, TypeInfo
- ):
- # type
- related_object_type = target_cls_arg.node
- python_type_for_type = Instance(related_object_type, [])
-
- # other cases not covered - an error message directs the user
- # to set an explicit type annotation
- #
- # node.type == str, it's a string
- # if isinstance(target_cls_arg, NameExpr) and isinstance(
- # target_cls_arg.node, Var
- # )
- # points to a type
- # isinstance(target_cls_arg, NameExpr) and isinstance(
- # target_cls_arg.node, TypeAlias
- # )
- # string expression
- # isinstance(target_cls_arg, StrExpr)
-
- uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
- collection_cls_arg = util._get_callexpr_kwarg(
- stmt.rvalue, "collection_class"
- )
-
- # this can be used to determine Optional for a many-to-one
- # in the same way nullable=False could be used, if we start supporting
- # that.
- # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin")
-
- if (
- uselist_arg is not None
- and uselist_arg.fullname == "builtins.True"
- and collection_cls_arg is None
- ):
- if python_type_for_type is not None:
- python_type_for_type = Instance(
- api.lookup_fully_qualified("builtins.list").node,
- [python_type_for_type],
- )
- elif (
- uselist_arg is None or uselist_arg.fullname == "builtins.True"
- ) and collection_cls_arg is not None:
- if isinstance(collection_cls_arg.node, TypeInfo):
- if python_type_for_type is not None:
- python_type_for_type = Instance(
- collection_cls_arg.node, [python_type_for_type]
- )
- else:
- util.fail(
- api,
- "Expected Python collection type for "
- "collection_class parameter",
- stmt.rvalue,
- )
- python_type_for_type = None
- elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
- if collection_cls_arg is not None:
- util.fail(
- api,
- "Sending uselist=False and collection_class at the same time "
- "does not make sense",
- stmt.rvalue,
- )
- if python_type_for_type is not None:
- python_type_for_type = UnionType(
- [python_type_for_type, NoneType()]
- )
-
- else:
- if left_hand_explicit_type is None:
- msg = (
- "Can't infer scalar or collection for ORM mapped expression "
- "assigned to attribute '{}' if both 'uselist' and "
- "'collection_class' arguments are absent from the "
- "relationship(); please specify a "
- "type annotation on the left hand side."
- )
- util.fail(api, msg.format(node.name), node)
-
- if python_type_for_type is None:
- return _infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
- elif left_hand_explicit_type is not None:
- return _infer_type_from_left_and_inferred_right(
- api, node, left_hand_explicit_type, python_type_for_type
- )
- else:
- return python_type_for_type
-
-
-def _infer_type_from_decl_composite_property(
- api: SemanticAnalyzerPluginInterface,
- stmt: AssignmentStmt,
- node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
- """Infer the type of mapping from a CompositeProperty."""
-
- assert isinstance(stmt.rvalue, CallExpr)
- target_cls_arg = stmt.rvalue.args[0]
- python_type_for_type = None
-
- if isinstance(target_cls_arg, NameExpr) and isinstance(
- target_cls_arg.node, TypeInfo
- ):
- related_object_type = target_cls_arg.node
- python_type_for_type = Instance(related_object_type, [])
- else:
- python_type_for_type = None
-
- if python_type_for_type is None:
- return _infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
- elif left_hand_explicit_type is not None:
- return _infer_type_from_left_and_inferred_right(
- api, node, left_hand_explicit_type, python_type_for_type
- )
- else:
- return python_type_for_type
-
-
-def _infer_type_from_decl_column_property(
- api: SemanticAnalyzerPluginInterface,
- stmt: AssignmentStmt,
- node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
- """Infer the type of mapping from a ColumnProperty.
-
- This includes mappings against ``column_property()`` as well as the
- ``deferred()`` function.
-
- """
- assert isinstance(stmt.rvalue, CallExpr)
- first_prop_arg = stmt.rvalue.args[0]
-
- if isinstance(first_prop_arg, CallExpr):
- type_id = names._type_id_for_callee(first_prop_arg.callee)
- else:
- type_id = None
-
- print(stmt.lvalues[0].name)
-
- # look for column_property() / deferred() etc with Column as first
- # argument
- if type_id is names.COLUMN:
- return _infer_type_from_decl_column(
- api, stmt, node, left_hand_explicit_type, first_prop_arg
- )
- else:
- return _infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
-
-
-def _infer_type_from_decl_column(
- api: SemanticAnalyzerPluginInterface,
- stmt: AssignmentStmt,
- node: Var,
- left_hand_explicit_type: Optional[types.Type],
- right_hand_expression: CallExpr,
-) -> Union[Instance, UnionType, None]:
- """Infer the type of mapping from a Column.
-
- E.g.::
-
- @reg.mapped
- class MyClass:
- # ...
-
- a = Column(Integer)
-
- b = Column("b", String)
-
- c: Mapped[int] = Column(Integer)
-
- d: bool = Column(Boolean)
-
- Will resolve in MyPy as::
-
- @reg.mapped
- class MyClass:
- # ...
-
- a : Mapped[int]
-
- b : Mapped[str]
-
- c: Mapped[int]
-
- d: Mapped[bool]
-
- """
- assert isinstance(node, Var)
-
- callee = None
-
- for column_arg in right_hand_expression.args[0:2]:
- if isinstance(column_arg, nodes.CallExpr):
- # x = Column(String(50))
- callee = column_arg.callee
- type_args = column_arg.args
- break
- elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)):
- if isinstance(column_arg.node, TypeInfo):
- # x = Column(String)
- callee = column_arg
- type_args = ()
- break
- else:
- # x = Column(some_name, String), go to next argument
- continue
- elif isinstance(column_arg, (StrExpr,)):
- # x = Column("name", String), go to next argument
- continue
- else:
- assert False
-
- if callee is None:
- return None
-
- if isinstance(callee.node, TypeInfo) and names._mro_has_id(
- callee.node.mro, names.TYPEENGINE
- ):
- python_type_for_type = _extract_python_type_from_typeengine(
- api, callee.node, type_args
- )
-
- if left_hand_explicit_type is not None:
-
- return _infer_type_from_left_and_inferred_right(
- api, node, left_hand_explicit_type, python_type_for_type
- )
-
- else:
- python_type_for_type = UnionType(
- [python_type_for_type, NoneType()]
- )
- return python_type_for_type
- else:
- # it's not TypeEngine, it's typically implicitly typed
- # like ForeignKey. we can't infer from the right side.
- return _infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
-
-
-def _infer_type_from_left_and_inferred_right(
- api: SemanticAnalyzerPluginInterface,
- node: Var,
- left_hand_explicit_type: Optional[types.Type],
- python_type_for_type: Union[Instance, UnionType],
-) -> Optional[Union[Instance, UnionType]]:
- """Validate type when a left hand annotation is present and we also
- could infer the right hand side::
-
- attrname: SomeType = Column(SomeDBType)
-
- """
- if not is_subtype(left_hand_explicit_type, python_type_for_type):
- descriptor = api.lookup("__sa_Mapped", node)
-
- effective_type = Instance(descriptor.node, [python_type_for_type])
-
- msg = (
- "Left hand assignment '{}: {}' not compatible "
- "with ORM mapped expression of type {}"
- )
- util.fail(
- api,
- msg.format(
- node.name,
- format_type(left_hand_explicit_type),
- format_type(effective_type),
- ),
- node,
- )
-
- return left_hand_explicit_type
-
-
-def _infer_type_from_left_hand_type_only(
- api: SemanticAnalyzerPluginInterface,
- node: Var,
- left_hand_explicit_type: Optional[types.Type],
-) -> Optional[Union[Instance, UnionType]]:
- """Determine the type based on explicit annotation only.
-
- if no annotation were present, note that we need one there to know
- the type.
-
- """
- if left_hand_explicit_type is None:
- msg = (
- "Can't infer type from ORM mapped expression "
- "assigned to attribute '{}'; please specify a "
- "Python type or "
- "Mapped[<python type>] on the left hand side."
- )
- util.fail(api, msg.format(node.name), node)
-
- descriptor = api.lookup("__sa_Mapped", node)
- return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)])
-
- else:
- # use type from the left hand side
- return left_hand_explicit_type
-
-
-def _re_apply_declarative_assignments(
- cls: ClassDef,
- api: SemanticAnalyzerPluginInterface,
- cls_metadata: DeclClassApplied,
-):
- """For multiple class passes, re-apply our left-hand side types as mypy
- seems to reset them in place.
-
- """
- mapped_attr_lookup = {
- name: typ for name, typ in cls_metadata.mapped_attr_names
- }
-
- descriptor = api.lookup("__sa_Mapped", cls)
- for stmt in cls.defs.body:
- # for a re-apply, all of our statements are AssignmentStmt;
- # @declared_attr calls will have been converted and this
- # currently seems to be preserved by mypy (but who knows if this
- # will change).
- if (
- isinstance(stmt, AssignmentStmt)
- and stmt.lvalues[0].name in mapped_attr_lookup
- ):
- typ = mapped_attr_lookup[stmt.lvalues[0].name]
- left_node = stmt.lvalues[0].node
-
- inst = Instance(descriptor.node, [typ])
- left_node.type = inst
-
-
-def _apply_type_to_mapped_statement(
- api: SemanticAnalyzerPluginInterface,
- stmt: AssignmentStmt,
- lvalue: NameExpr,
- left_hand_explicit_type: Optional[Union[Instance, UnionType]],
- python_type_for_type: Union[Instance, UnionType],
-) -> None:
- """Apply the Mapped[<type>] annotation and right hand object to a
- declarative assignment statement.
-
- This converts a Python declarative class statement such as::
-
- class User(Base):
- # ...
-
- attrname = Column(Integer)
-
- To one that describes the final Python behavior to Mypy::
-
- class User(Base):
- # ...
-
- attrname : Mapped[Optional[int]] = <meaningless temp node>
-
- """
- descriptor = api.lookup("__sa_Mapped", stmt)
- left_node = lvalue.node
-
- inst = Instance(descriptor.node, [python_type_for_type])
-
- if left_hand_explicit_type is not None:
- left_node.type = Instance(descriptor.node, [left_hand_explicit_type])
- else:
- lvalue.is_inferred_def = False
- left_node.type = inst
-
- # so to have it skip the right side totally, we can do this:
- # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
-
- # however, if we instead manufacture a new node that uses the old
- # one, then we can still get type checking for the call itself,
- # e.g. the Column, relationship() call, etc.
-
- # rewrite the node as:
- # <attr> : Mapped[<typ>] =
- # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
- # the original right-hand side is maintained so it gets type checked
- # internally
- api.add_symbol_table_node("_sa_Mapped", descriptor)
- column_descriptor = nodes.NameExpr("_sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.Mapped"
- mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
- orig_call_expr = stmt.rvalue
- stmt.rvalue = CallExpr(
- mm,
- [orig_call_expr],
- [nodes.ARG_POS],
- ["arg1"],
- )
-
-
def _scan_for_mapped_bases(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: DeclClassApplied,
+ cls_metadata: util.DeclClassApplied,
) -> None:
"""Given a class, iterate through its superclass hierarchy to find
all other classes that are considered as ORM-significant.
"""
- baseclasses = list(cls.info.bases)
+ info = util._info_for_cls(cls, api)
+
+ baseclasses = list(info.bases)
+
while baseclasses:
base: Instance = baseclasses.pop(0)
+ if base.type.fullname.startswith("builtins"):
+ continue
+
# scan each base for mapped attributes. if they are not already
- # scanned, that means they are unmapped mixins
+ # scanned (but have all their type info), that means they are unmapped
+ # mixins
base_decl_class_applied = (
_scan_declarative_assignments_and_apply_types(
base.type.defn, api, is_mixin_scan=True
)
)
- if base_decl_class_applied is not None:
+
+ if base_decl_class_applied not in (None, False):
cls_metadata.mapped_mro.append(base)
baseclasses.extend(base.type.bases)
-
-
-def _add_additional_orm_attributes(
- cls: ClassDef,
- api: SemanticAnalyzerPluginInterface,
- cls_metadata: DeclClassApplied,
-) -> None:
- """Apply __init__, __table__ and other attributes to the mapped class."""
- if "__init__" not in cls.info.names and cls_metadata.is_mapped:
- mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
-
- for mapped_base in cls_metadata.mapped_mro:
- base_cls_metadata = DeclClassApplied.deserialize(
- mapped_base.type.metadata["_sa_decl_class_applied"], api
- )
- for n, t in base_cls_metadata.mapped_attr_names:
- mapped_attr_names.setdefault(n, t)
-
- arguments = []
- for name, typ in mapped_attr_names.items():
- if typ is None:
- typ = AnyType(TypeOfAny.special_form)
- arguments.append(
- Argument(
- variable=Var(name, typ),
- type_annotation=typ,
- initializer=TempNode(typ),
- kind=ARG_NAMED_OPT,
- )
- )
- add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
-
- if "__table__" not in cls.info.names and cls_metadata.has_table:
- _apply_placeholder_attr_to_class(
- api, cls, "sqlalchemy.sql.schema.Table", "__table__"
- )
- if cls_metadata.is_mapped:
- _apply_placeholder_attr_to_class(
- api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
- )
-
-
-def _apply_placeholder_attr_to_class(
- api: SemanticAnalyzerPluginInterface,
- cls: ClassDef,
- qualified_name: str,
- attrname: str,
-):
- sym = api.lookup_fully_qualified_or_none(qualified_name)
- if sym:
- assert isinstance(sym.node, TypeInfo)
- type_ = Instance(sym.node, [])
- else:
- type_ = AnyType(TypeOfAny.special_form)
- var = Var(attrname)
- var.info = cls.info
- var.type = type_
- cls.info.names[attrname] = SymbolTableNode(MDEF, var)
-
-
-def _extract_python_type_from_typeengine(
- api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args
-) -> Instance:
- if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
- first_arg = type_args[0]
- if isinstance(first_arg, NameExpr) and isinstance(
- first_arg.node, TypeInfo
- ):
- for base_ in first_arg.node.mro:
- if base_.fullname == "enum.Enum":
- return Instance(first_arg.node, [])
- # TODO: support other pep-435 types here
- else:
- n = api.lookup_fully_qualified("builtins.str")
- return Instance(n.node, [])
-
- for mr in node.mro:
- if mr.bases:
- for base_ in mr.bases:
- if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine":
- return base_.args[-1]
- assert False, "could not extract Python type from node: %s" % node
--- /dev/null
+# ext/mypy/infer.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Union
+
+from mypy import nodes
+from mypy import types
+from mypy.messages import format_type
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.subtypes import is_subtype
+from mypy.types import AnyType
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import TypeOfAny
+from mypy.types import UnionType
+
+from . import names
+from . import util
+
+
+def _infer_type_from_relationship(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[types.Type],
+) -> Union[Instance, UnionType, None]:
+ """Infer the type of mapping from a relationship.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses = relationship(Address, uselist=True)
+
+ order: Mapped["Order"] = relationship("Order")
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses: Mapped[List[Address]]
+
+ order: Mapped["Order"]
+
+ """
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ # type
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+
+ # other cases not covered - an error message directs the user
+ # to set an explicit type annotation
+ #
+ # node.type == str, it's a string
+ # if isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, Var
+ # )
+ # points to a type
+ # isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, TypeAlias
+ # )
+ # string expression
+ # isinstance(target_cls_arg, StrExpr)
+
+ uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
+ collection_cls_arg = util._get_callexpr_kwarg(
+ stmt.rvalue, "collection_class"
+ )
+
+ # this can be used to determine Optional for a many-to-one
+ # in the same way nullable=False could be used, if we start supporting
+ # that.
+ # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+
+ if (
+ uselist_arg is not None
+ and uselist_arg.fullname == "builtins.True"
+ and collection_cls_arg is None
+ ):
+ if python_type_for_type is not None:
+ python_type_for_type = Instance(
+ api.lookup_fully_qualified("builtins.list").node,
+ [python_type_for_type],
+ )
+ elif (
+ uselist_arg is None or uselist_arg.fullname == "builtins.True"
+ ) and collection_cls_arg is not None:
+ if isinstance(collection_cls_arg.node, TypeInfo):
+ if python_type_for_type is not None:
+ python_type_for_type = Instance(
+ collection_cls_arg.node, [python_type_for_type]
+ )
+ else:
+ util.fail(
+ api,
+ "Expected Python collection type for "
+ "collection_class parameter",
+ stmt.rvalue,
+ )
+ python_type_for_type = None
+ elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
+ if collection_cls_arg is not None:
+ util.fail(
+ api,
+ "Sending uselist=False and collection_class at the same time "
+ "does not make sense",
+ stmt.rvalue,
+ )
+ if python_type_for_type is not None:
+ python_type_for_type = UnionType(
+ [python_type_for_type, NoneType()]
+ )
+
+ else:
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer scalar or collection for ORM mapped expression "
+ "assigned to attribute '{}' if both 'uselist' and "
+ "'collection_class' arguments are absent from the "
+ "relationship(); please specify a "
+ "type annotation on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ if python_type_for_type is None:
+ return _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_composite_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[types.Type],
+) -> Union[Instance, UnionType, None]:
+ """Infer the type of mapping from a CompositeProperty."""
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+ else:
+ python_type_for_type = None
+
+ if python_type_for_type is None:
+ return _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_column_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[types.Type],
+) -> Union[Instance, UnionType, None]:
+ """Infer the type of mapping from a ColumnProperty.
+
+ This includes mappings against ``column_property()`` as well as the
+ ``deferred()`` function.
+
+ """
+ assert isinstance(stmt.rvalue, CallExpr)
+ first_prop_arg = stmt.rvalue.args[0]
+
+ if isinstance(first_prop_arg, CallExpr):
+ type_id = names._type_id_for_callee(first_prop_arg.callee)
+ else:
+ type_id = None
+
+ # look for column_property() / deferred() etc with Column as first
+ # argument
+ if type_id is names.COLUMN:
+ return _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type, first_prop_arg
+ )
+ else:
+ return _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_decl_column(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[types.Type],
+ right_hand_expression: CallExpr,
+) -> Union[Instance, UnionType, None]:
+ """Infer the type of mapping from a Column.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a = Column(Integer)
+
+ b = Column("b", String)
+
+ c: Mapped[int] = Column(Integer)
+
+ d: bool = Column(Boolean)
+
+ Will resolve in MyPy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a : Mapped[int]
+
+ b : Mapped[str]
+
+ c: Mapped[int]
+
+ d: Mapped[bool]
+
+ """
+ assert isinstance(node, Var)
+
+ callee = None
+
+ for column_arg in right_hand_expression.args[0:2]:
+ if isinstance(column_arg, nodes.CallExpr):
+ # x = Column(String(50))
+ callee = column_arg.callee
+ type_args = column_arg.args
+ break
+ elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)):
+ if isinstance(column_arg.node, TypeInfo):
+ # x = Column(String)
+ callee = column_arg
+ type_args = ()
+ break
+ else:
+ # x = Column(some_name, String), go to next argument
+ continue
+ elif isinstance(column_arg, (StrExpr,)):
+ # x = Column("name", String), go to next argument
+ continue
+ else:
+ assert False
+
+ if callee is None:
+ return None
+
+ if isinstance(callee.node, TypeInfo) and names._mro_has_id(
+ callee.node.mro, names.TYPEENGINE
+ ):
+ python_type_for_type = _extract_python_type_from_typeengine(
+ api, callee.node, type_args
+ )
+
+ if left_hand_explicit_type is not None:
+
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+
+ else:
+ python_type_for_type = UnionType(
+ [python_type_for_type, NoneType()]
+ )
+ return python_type_for_type
+ else:
+ # it's not TypeEngine, it's typically implicitly typed
+ # like ForeignKey. we can't infer from the right side.
+ return _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Optional[types.Type],
+ python_type_for_type: Union[Instance, UnionType],
+) -> Optional[Union[Instance, UnionType]]:
+ """Validate type when a left hand annotation is present and we also
+ could infer the right hand side::
+
+ attrname: SomeType = Column(SomeDBType)
+
+ """
+ if not is_subtype(left_hand_explicit_type, python_type_for_type):
+ descriptor = api.lookup("__sa_Mapped", node)
+
+ effective_type = Instance(descriptor.node, [python_type_for_type])
+
+ msg = (
+ "Left hand assignment '{}: {}' not compatible "
+ "with ORM mapped expression of type {}"
+ )
+ util.fail(
+ api,
+ msg.format(
+ node.name,
+ format_type(left_hand_explicit_type),
+ format_type(effective_type),
+ ),
+ node,
+ )
+
+ return left_hand_explicit_type
+
+
+def _infer_type_from_left_hand_type_only(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Optional[types.Type],
+) -> Optional[Union[Instance, UnionType]]:
+ """Determine the type based on explicit annotation only.
+
+    If no annotation is present, note that we need one there to know
+ the type.
+
+ """
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer type from ORM mapped expression "
+ "assigned to attribute '{}'; please specify a "
+ "Python type or "
+ "Mapped[<python type>] on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ descriptor = api.lookup("__sa_Mapped", node)
+ return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)])
+
+ else:
+ # use type from the left hand side
+ return left_hand_explicit_type
+
+
+def _extract_python_type_from_typeengine(
+ api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args
+) -> Instance:
+ if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
+ first_arg = type_args[0]
+ if isinstance(first_arg, NameExpr) and isinstance(
+ first_arg.node, TypeInfo
+ ):
+ for base_ in first_arg.node.mro:
+ if base_.fullname == "enum.Enum":
+ return Instance(first_arg.node, [])
+ # TODO: support other pep-435 types here
+ else:
+ n = api.lookup_fully_qualified("builtins.str")
+ return Instance(n.node, [])
+
+ for mr in node.mro:
+ if mr.bases:
+ for base_ in mr.bases:
+ if base_.type.fullname == "sqlalchemy.sql.type_api.TypeEngine":
+ return base_.args[-1]
+ assert False, "could not extract Python type from node: %s" % node
MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY")
AS_DECLARATIVE = util.symbol("AS_DECLARATIVE")
AS_DECLARATIVE_BASE = util.symbol("AS_DECLARATIVE_BASE")
+DECLARATIVE_MIXIN = util.symbol("DECLARATIVE_MIXIN")
_lookup = {
"Column": (
"sqlalchemy.orm.declared_attr",
},
),
+ "declarative_mixin": (
+ DECLARATIVE_MIXIN,
+ {
+ "sqlalchemy.orm.decl_api.declarative_mixin",
+ "sqlalchemy.orm.declarative_mixin",
+ },
+ ),
}
# subclasses. but then you can just check it here from the "base"
# and get the same effect.
sym = self.lookup_fully_qualified(fullname)
+
if (
sym
and isinstance(sym.node, TypeInfo)
) -> Optional[Callable[[ClassDefContext], None]]:
sym = self.lookup_fully_qualified(fullname)
- if (
- sym is not None
- and names._type_id_for_named_node(sym.node)
- is names.MAPPED_DECORATOR
- ):
- return _cls_decorator_hook
- elif sym is not None and names._type_id_for_named_node(sym.node) in (
- names.AS_DECLARATIVE,
- names.AS_DECLARATIVE_BASE,
- ):
- return _base_cls_decorator_hook
+
+ if sym is not None:
+ type_id = names._type_id_for_named_node(sym.node)
+ if type_id is names.MAPPED_DECORATOR:
+ return _cls_decorator_hook
+ elif type_id in (
+ names.AS_DECLARATIVE,
+ names.AS_DECLARATIVE_BASE,
+ ):
+ return _base_cls_decorator_hook
+ elif type_id is names.DECLARATIVE_MIXIN:
+ return _declarative_mixin_hook
return None
decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ decl_class._scan_declarative_assignments_and_apply_types(
+ ctx.cls, ctx.api, is_mixin_scan=True
+ )
+
+
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
assert isinstance(ctx.reason, nodes.MemberExpr)
from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
from mypy.nodes import CallExpr
+from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
from mypy.nodes import IfStmt
+from mypy.nodes import JsonDict
from mypy.nodes import NameExpr
from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import deserialize_and_fixup_type
from mypy.types import Instance
from mypy.types import NoneType
-from mypy.types import Type
from mypy.types import UnboundType
from mypy.types import UnionType
+class DeclClassApplied:
+ def __init__(
+ self,
+ is_mapped: bool,
+ has_table: bool,
+ mapped_attr_names: Sequence[Tuple[str, Type]],
+ mapped_mro: Sequence[Type],
+ ):
+ self.is_mapped = is_mapped
+ self.has_table = has_table
+ self.mapped_attr_names = mapped_attr_names
+ self.mapped_mro = mapped_mro
+
+ def serialize(self) -> JsonDict:
+ return {
+ "is_mapped": self.is_mapped,
+ "has_table": self.has_table,
+ "mapped_attr_names": [
+ (name, type_.serialize())
+ for name, type_ in self.mapped_attr_names
+ ],
+ "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
+ }
+
+ @classmethod
+ def deserialize(
+ cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
+ ) -> "DeclClassApplied":
+
+ return DeclClassApplied(
+ is_mapped=data["is_mapped"],
+ has_table=data["has_table"],
+ mapped_attr_names=[
+ (name, deserialize_and_fixup_type(type_, api))
+ for name, type_ in data["mapped_attr_names"]
+ ],
+ mapped_mro=[
+ deserialize_and_fixup_type(type_, api)
+ for type_ in data["mapped_mro"]
+ ],
+ )
+
+
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
msg = "[SQLAlchemy Mypy plugin] %s" % msg
return api.fail(msg, ctx)
)
else:
return typ
+
+
+def _info_for_cls(cls, api):
+ if cls.info is CLASSDEF_NO_INFO:
+ sym = api.lookup(cls.name, cls)
+ if sym.node and isinstance(sym.node, TypeInfo):
+ info = sym.node
+ else:
+ info = cls.info
+
+ return info
from .context import QueryContext
from .decl_api import as_declarative
from .decl_api import declarative_base
+from .decl_api import declarative_mixin
from .decl_api import DeclarativeMeta
from .decl_api import declared_attr
from .decl_api import has_inherited_table
return declared_attr(fn, **self.kw)
+def declarative_mixin(cls):
+ """Mark a class as providing the feature of "declarative mixin".
+
+ E.g.::
+
+ from sqlalchemy.orm import declared_attr
+ from sqlalchemy.orm import declarative_mixin
+
+ @declarative_mixin
+ class MyMixin:
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ __table_args__ = {'mysql_engine': 'InnoDB'}
+ __mapper_args__= {'always_refresh': True}
+
+ id = Column(Integer, primary_key=True)
+
+ class MyModel(MyMixin, Base):
+ name = Column(String(1000))
+
+ The :func:`_orm.declarative_mixin` decorator currently does not modify
+    the given class in any way; its current purpose is strictly to assist
+ the :ref:`Mypy plugin <mypy_toplevel>` in being able to identify
+ SQLAlchemy declarative mixin classes when no other context is present.
+
+ .. versionadded:: 1.4.6
+
+ .. seealso::
+
+ :ref:`orm_mixins_toplevel`
+
+ :ref:`mypy_declarative_mixins` - in the
+ :ref:`Mypy plugin documentation <mypy_toplevel>`
+
+ """ # noqa: E501
+
+ return cls
+
+
def declarative_base(
bind=None,
metadata=None,
--- /dev/null
+from typing import Callable
+
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import deferred
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm.decl_api import declarative_mixin
+from sqlalchemy.orm.decl_api import declared_attr
+from sqlalchemy.orm.interfaces import MapperProperty
+
+
+def some_other_decorator(fn: Callable[..., None]) -> Callable[..., None]:
+ return fn
+
+
+@declarative_mixin
+class HasAMixin:
+ x: Mapped[int] = Column(Integer)
+
+ y = Column(String)
+
+ @declared_attr
+ def data(cls) -> Column[String]:
+ return Column(String)
+
+ @declared_attr
+ def data2(cls) -> MapperProperty[str]:
+ return deferred(Column(String))
+
+ @some_other_decorator
+ def q(cls) -> None:
+ return None
--- /dev/null
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()
--- /dev/null
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from .base import Base
+
+
+class One(Base):
+ __tablename__ = "one"
+ id = Column(Integer, primary_key=True)
+
+
+o1 = One(id=5)
+
+One.id.in_([1, 2])
--- /dev/null
+--- a/one.py 2021-04-03 15:32:22.214287290 -0400
++++ b/one.py 2021-04-03 15:34:56.397398510 -0400
+@@ -1,15 +1,13 @@
+ from sqlalchemy import Column
+ from sqlalchemy import Integer
+-from sqlalchemy import String
+ from .base import Base
+
+
+ class One(Base):
+ __tablename__ = "one"
+ id = Column(Integer, primary_key=True)
+- name = Column(String(50))
+
+
+-o1 = One(id=5, name="name")
++o1 = One(id=5)
+
+ One.id.in_([1, 2])
--- /dev/null
+--- a/base.py 2021-04-03 16:36:30.201594994 -0400
++++ b/base.py 2021-04-03 16:38:26.404475025 -0400
+@@ -1,3 +1,15 @@
++from sqlalchemy import Column
++from sqlalchemy import Integer
++from sqlalchemy import String
+ from sqlalchemy.orm import declarative_base
++from sqlalchemy.orm import declarative_mixin
++from sqlalchemy.orm import Mapped
+
+ Base = declarative_base()
++
++
++@declarative_mixin
++class Mixin:
++ mixed = Column(String)
++
++ b_int: Mapped[int] = Column(Integer)
+--- a/one.py 2021-04-03 16:37:17.906956282 -0400
++++ b/one.py 2021-04-03 16:38:33.469528528 -0400
+@@ -1,13 +1,15 @@
+ from sqlalchemy import Column
+ from sqlalchemy import Integer
++
+ from .base import Base
++from .base import Mixin
+
+
+-class One(Base):
++class One(Mixin, Base):
+ __tablename__ = "one"
+ id = Column(Integer, primary_key=True)
+
+
+-o1 = One(id=5)
++o1 = One(id=5, mixed="mixed", b_int=5)
+
+ One.id.in_([1, 2])
class MypyPluginTest(fixtures.TestBase):
__requires__ = ("sqlalchemy2_stubs",)
+ @testing.fixture(scope="function")
+ def per_func_cachedir(self):
+ for item in self._cachedir():
+ yield item
+
@testing.fixture(scope="class")
def cachedir(self):
+ for item in self._cachedir():
+ yield item
+
+ def _cachedir(self):
with tempfile.TemporaryDirectory() as cachedir:
with open(
os.path.join(cachedir, "sqla_mypy_config.cfg"), "w"
*[(dirname,) for dirname in _incremental_dirs()], argnames="dirname"
)
@testing.requires.patch_library
- def test_incremental(self, mypy_runner, cachedir, dirname):
+ def test_incremental(self, mypy_runner, per_func_cachedir, dirname):
import patch
+ cachedir = per_func_cachedir
+
path = os.path.join(os.path.dirname(__file__), "incremental", dirname)
dest = os.path.join(cachedir, "mymodel")
os.mkdir(dest)
if patchfile is not None:
print("Applying patchfile %s" % patchfile)
patch_obj = patch.fromfile(os.path.join(path, patchfile))
- patch_obj.apply(1, dest)
+ assert patch_obj.apply(1, dest), (
+                        "patchfile %s failed" % patchfile
+ )
print("running mypy against %s/mymodel" % cachedir)
result = mypy_runner(
"mymodel",
from sqlalchemy.orm import column_property
from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import declarative_mixin
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import deferred
from sqlalchemy.orm import events as orm_events
eq_(obj.name, "testing")
eq_(obj.foo(), "bar1")
+ def test_declarative_mixin_decorator(self):
+
+        # note we are also making sure that an "old style class" under
+        # Python 2, as now illustrated in all the docs for mixins, doesn't
+        # cause a problem....
+ @declarative_mixin
+ class MyMixin:
+
+ id = Column(
+ Integer, primary_key=True, test_needs_autoincrement=True
+ )
+
+ def foo(self):
+ return "bar" + str(self.id)
+
+ # ...as long as the mapped class itself is "new style", which will
+ # normally be the case for users using declarative_base
+ @mapper_registry.mapped
+ class MyModel(MyMixin, object):
+
+ __tablename__ = "test"
+ name = Column(String(100), nullable=False, index=True)
+
+ Base.metadata.create_all(testing.db)
+ session = fixture_session()
+ session.add(MyModel(name="testing"))
+ session.flush()
+ session.expunge_all()
+ obj = session.query(MyModel).one()
+ eq_(obj.id, 1)
+ eq_(obj.name, "testing")
+ eq_(obj.foo(), "bar1")
+
def test_unique_column(self):
class MyMixin(object):