--- /dev/null
+.. change::
+ :tags: bug, mypy
+ :tickets: 6255
+
+ Fixed issue where mypy plugin would not correctly interpret an explicit
+ :class:`_orm.Mapped` annotation in conjunction with a
+ :func:`_orm.relationship` that refers to a class by string name; the
+ correct annotation would be downgraded to a less specific one leading to
+ typing errors.
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import MDEF
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolTableNode
from mypy.types import UnboundType
from mypy.types import UnionType
+from . import infer
from . import util
mapped_attr_lookup = {
name: typ for name, typ in cls_metadata.mapped_attr_names
}
+ update_cls_metadata = False
for stmt in cls.defs.body:
# for a re-apply, all of our statements are AssignmentStmt;
and stmt.lvalues[0].name in mapped_attr_lookup
and isinstance(stmt.lvalues[0].node, Var)
):
- typ = mapped_attr_lookup[stmt.lvalues[0].name]
+
left_node = stmt.lvalues[0].node
+ python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name]
+ # if we have scanned an UnboundType and now there's a more
+ # specific type than UnboundType, call the re-scan so we
+ # can get that set up correctly
+ if (
+ isinstance(python_type_for_type, UnboundType)
+ and not isinstance(left_node.type, UnboundType)
+ and (
+ isinstance(stmt.rvalue.callee, MemberExpr)
+ and stmt.rvalue.callee.expr.node.fullname
+ == "sqlalchemy.orm.attributes.Mapped"
+ and stmt.rvalue.callee.name == "_empty_constructor"
+ and isinstance(stmt.rvalue.args[0], CallExpr)
+ )
+ ):
+
+ python_type_for_type = (
+ infer._infer_type_from_right_hand_nameexpr(
+ api,
+ stmt,
+ left_node,
+ left_node.type,
+ stmt.rvalue.args[0].callee,
+ )
+ )
+
+ if python_type_for_type is None or isinstance(
+ python_type_for_type, UnboundType
+ ):
+ continue
+
+ # update the DeclClassApplied with the better information
+ mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type
+ update_cls_metadata = True
+
+ left_node.type = api.named_type(
+ "__sa_Mapped", [python_type_for_type]
+ )
- left_node.type = api.named_type("__sa_Mapped", [typ])
+ if update_cls_metadata:
+ cls_metadata.mapped_attr_names[:] = [
+ (k, v) for k, v in mapped_attr_lookup.items()
+ ]
def _apply_type_to_mapped_statement(
if not is_mixin_scan:
assert cls_metadata.is_mapped
- # mypy can call us more than once. it then will have reset the
+ # mypy can call us more than once. it then *may* have reset the
# left hand side of everything, but not the right that we removed,
# removing our ability to re-scan. but we have the types
- # here, so lets re-apply them.
+ # here, so lets re-apply them, or if we have an UnboundType,
+ # we can re-scan
apply._re_apply_declarative_assignments(cls, api, cls_metadata)
stmt.rvalue.callee, RefExpr
):
- type_id = names._type_id_for_callee(stmt.rvalue.callee)
+ python_type_for_type = infer._infer_type_from_right_hand_nameexpr(
+ api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
+ )
- if type_id is None:
- return
- elif type_id is names.COLUMN:
- python_type_for_type = infer._infer_type_from_decl_column(
- api, stmt, node, left_hand_explicit_type, stmt.rvalue
- )
- elif type_id is names.RELATIONSHIP:
- python_type_for_type = infer._infer_type_from_relationship(
- api, stmt, node, left_hand_explicit_type
- )
- elif type_id is names.COLUMN_PROPERTY:
- python_type_for_type = infer._infer_type_from_decl_column_property(
- api, stmt, node, left_hand_explicit_type
- )
- elif type_id is names.SYNONYM_PROPERTY:
- python_type_for_type = infer._infer_type_from_left_hand_type_only(
- api, node, left_hand_explicit_type
- )
- elif type_id is names.COMPOSITE_PROPERTY:
- python_type_for_type = (
- infer._infer_type_from_decl_composite_property(
- api, stmt, node, left_hand_explicit_type
- )
- )
- else:
+ if python_type_for_type is None:
return
else:
from . import util
+def _infer_type_from_right_hand_nameexpr(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ infer_from_right_side: NameExpr,
+) -> Optional[ProperType]:
+
+ type_id = names._type_id_for_callee(infer_from_right_side)
+
+ if type_id is None:
+ return None
+ elif type_id is names.COLUMN:
+ python_type_for_type = _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.RELATIONSHIP:
+ python_type_for_type = _infer_type_from_relationship(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.COLUMN_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_column_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.SYNONYM_PROPERTY:
+ python_type_for_type = _infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif type_id is names.COMPOSITE_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_composite_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ else:
+ return None
+
+ return python_type_for_type
+
+
def _infer_type_from_relationship(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
# argument
if type_id is names.COLUMN:
return _infer_type_from_decl_column(
- api, stmt, node, left_hand_explicit_type, first_prop_arg
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ right_hand_expression=first_prop_arg,
)
return _infer_type_from_left_hand_type_only(
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
- right_hand_expression: CallExpr,
+ right_hand_expression: Optional[CallExpr] = None,
) -> Optional[ProperType]:
"""Infer the type of mapping from a Column.
callee = None
+ if right_hand_expression is None:
+ if not isinstance(stmt.rvalue, CallExpr):
+ return None
+
+ right_hand_expression = stmt.rvalue
+
for column_arg in right_hand_expression.args[0:2]:
if isinstance(column_arg, CallExpr):
if isinstance(column_arg.callee, RefExpr):
# min mypy version 0.800
strict = True
incremental = True
+plugins = sqlalchemy.ext.mypy.plugin
[mypy-sqlalchemy.*]
ignore_errors = True
--- /dev/null
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class User(Base):
+ __tablename__ = "user"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ addresses: Mapped[List["Address"]] = relationship(
+ "Address", back_populates="user"
+ )
+
+ @property
+ def some_property(self) -> List[Optional[int]]:
+ return [i.id for i in self.addresses]
+
+
+class Address(Base):
+ __tablename__ = "address"
+
+ id = Column(Integer, primary_key=True)
+ user_id: int = Column(ForeignKey("user.id"))
+
+ user: "User" = relationship("User", back_populates="addresses")
+
+ @property
+ def some_other_property(self) -> Optional[str]:
+ return self.user.name
+
+
+# it's in the constructor, correct type
+u1 = User(addresses=[Address()])
+
+# knows it's an iterable
+[x for x in u1.addresses]
+
+# knows it's Mapped
+stmt = select(User).where(User.addresses.any(id=5))
--- /dev/null
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class User(Base):
+ __tablename__ = "user"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ addresses: List["Address"] = relationship("Address", back_populates="user")
+
+ @property
+ def some_property(self) -> List[Optional[int]]:
+ return [i.id for i in self.addresses]
+
+
+class Address(Base):
+ __tablename__ = "address"
+
+ id = Column(Integer, primary_key=True)
+ user_id: int = Column(ForeignKey("user.id"))
+
+ user: "User" = relationship("User", back_populates="addresses")
+
+ @property
+ def some_other_property(self) -> Optional[str]:
+ return self.user.name
+
+
+# it's in the constructor, correct type
+u1 = User(addresses=[Address()])
+
+# knows it's an iterable
+[x for x in u1.addresses]
+
+# knows it's Mapped
+stmt = select(User).where(User.addresses.any(id=5))
--- /dev/null
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class User(Base):
+ __tablename__ = "user"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ addresses: List["Address"] = relationship("Address", back_populates="user")
+
+ @property
+ def some_property(self) -> List[Optional[int]]:
+ return [i.id for i in self.addresses]
+
+
+class Address(Base):
+ __tablename__ = "address"
+
+ id = Column(Integer, primary_key=True)
+ user_id: int = Column(ForeignKey("user.id"))
+
+ user: "User" = relationship("User", back_populates="addresses")
+
+ @property
+ def some_other_property(self) -> Optional[str]:
+ return self.user.name
+
+
+# it's in the constructor, correct type
+u1 = User(addresses=[Address()])
+
+# knows it's an iterable
+[x for x in u1.addresses]
+
+# knows it's Mapped
+stmt = select(User).where(User.addresses.any(id=5))
errors.append(e)
for num, is_mypy, msg in expected_errors:
+ msg = msg.replace("'", '"')
prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
for idx, errmsg in enumerate(errors):
- if f"{filename}:{num + 1}: error: {prefix}{msg}" in errmsg:
+ if (
+ f"{filename}:{num + 1}: error: {prefix}{msg}"
+ in errmsg.replace("'", '"')
+ ):
break
else:
continue