From: Mike Bayer Date: Mon, 19 Apr 2021 22:03:12 +0000 (-0400) Subject: Re-infer statements that got more specific on subsequent pass X-Git-Tag: rel_1_4_10~5^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0e1a011aa3091aa2d6d95b269ff6da518db8e1a3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Re-infer statements that got more specific on subsequent pass Fixed issue where mypy plugin would not correctly interpret an explicit :class:`_orm.Mapped` annotation in conjunction with a :func:`_orm.relationship` that refers to a class by string name; the correct annotation would be downgraded to a less specific one leading to typing errors. The thing figured out here is that after we've already scanned a class in the semanal stage and created DeclClassApplied, when we are called again with that same DeclClassApplied, for this specific kind of case we actually now have *better* types than we did before, where the left side that looked like List?[Address?] now seems to say builtins.list[official.module.Address] - so let's take the right side expression again, this time embedded in our Mapped._empty_constructor() expression, and run the infer all over again just like mypy would. Just not setting the "wrong" type here fixed the test cases but by re-applying the whole infer we get the correct Mapped[] on the left side too. Fixes: #6255 Change-Id: Iafe7254374f685a8458c7a1db82aafc2ed6d0232 --- diff --git a/doc/build/changelog/unreleased_14/6255.rst b/doc/build/changelog/unreleased_14/6255.rst new file mode 100644 index 0000000000..0211fb3412 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6255.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, mypy + :tickets: 6255 + + Fixed issue where mypy plugin would not correctly interpret an explicit + :class:`_orm.Mapped` annotation in conjunction with a + :func:`_orm.relationship` that refers to a class by string name; the + correct annotation would be downgraded to a less specific one leading to + typing errors. diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 3662604373..5dc9ec0b17 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -15,6 +15,7 @@ from mypy.nodes import AssignmentStmt from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import MDEF +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import StrExpr from mypy.nodes import SymbolTableNode @@ -32,6 +33,7 @@ from mypy.types import TypeOfAny from mypy.types import UnboundType from mypy.types import UnionType +from . import infer from . import util @@ -92,6 +94,7 @@ def _re_apply_declarative_assignments( mapped_attr_lookup = { name: typ for name, typ in cls_metadata.mapped_attr_names } + update_cls_metadata = False for stmt in cls.defs.body: # for a re-apply, all of our statements are AssignmentStmt; @@ -104,10 +107,51 @@ def _re_apply_declarative_assignments( and stmt.lvalues[0].name in mapped_attr_lookup and isinstance(stmt.lvalues[0].node, Var) ): - typ = mapped_attr_lookup[stmt.lvalues[0].name] + left_node = stmt.lvalues[0].node + python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name] + # if we have scanned an UnboundType and now there's a more + # specific type than UnboundType, call the re-scan so we + # can get that set up correctly + if ( + isinstance(python_type_for_type, UnboundType) + and not isinstance(left_node.type, UnboundType) + and ( + isinstance(stmt.rvalue.callee, MemberExpr) + and stmt.rvalue.callee.expr.node.fullname + == "sqlalchemy.orm.attributes.Mapped" + and stmt.rvalue.callee.name == "_empty_constructor" + and isinstance(stmt.rvalue.args[0], CallExpr) + ) + ): + + python_type_for_type = ( + infer._infer_type_from_right_hand_nameexpr( + api, + stmt, + left_node, + left_node.type, + stmt.rvalue.args[0].callee, + ) + ) + + if python_type_for_type is None or isinstance( + python_type_for_type, UnboundType + ): + continue + + # update the DeclClassApplied with the better information + mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type + update_cls_metadata = True + + left_node.type = api.named_type( + "__sa_Mapped", [python_type_for_type] + ) - left_node.type = api.named_type("__sa_Mapped", [typ]) + if update_cls_metadata: + cls_metadata.mapped_attr_names[:] = [ + (k, v) for k, v in mapped_attr_lookup.items() + ] def _apply_type_to_mapped_statement( diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 8fac36342b..2870eeb6fb 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -63,10 +63,11 @@ def _scan_declarative_assignments_and_apply_types( if not is_mixin_scan: assert cls_metadata.is_mapped - # mypy can call us more than once. it then will have reset the + # mypy can call us more than once. it then *may* have reset the # left hand side of everything, but not the right that we removed, # removing our ability to re-scan. but we have the types - # here, so lets re-apply them. + # here, so lets re-apply them, or if we have an UnboundType, + # we can re-scan apply._re_apply_declarative_assignments(cls, api, cls_metadata) @@ -422,33 +423,11 @@ def _scan_declarative_assignment_stmt( stmt.rvalue.callee, RefExpr ): - type_id = names._type_id_for_callee(stmt.rvalue.callee) + python_type_for_type = infer._infer_type_from_right_hand_nameexpr( + api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee + ) - if type_id is None: - return - elif type_id is names.COLUMN: - python_type_for_type = infer._infer_type_from_decl_column( - api, stmt, node, left_hand_explicit_type, stmt.rvalue - ) - elif type_id is names.RELATIONSHIP: - python_type_for_type = infer._infer_type_from_relationship( - api, stmt, node, left_hand_explicit_type - ) - elif type_id is names.COLUMN_PROPERTY: - python_type_for_type = infer._infer_type_from_decl_column_property( - api, stmt, node, left_hand_explicit_type - ) - elif type_id is names.SYNONYM_PROPERTY: - python_type_for_type = infer._infer_type_from_left_hand_type_only( - api, node, left_hand_explicit_type - ) - elif type_id is names.COMPOSITE_PROPERTY: - python_type_for_type = ( - infer._infer_type_from_decl_composite_property( - api, stmt, node, left_hand_explicit_type - ) - ) - else: + if python_type_for_type is None: return else: diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index d734d588e5..2fea6d340a 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -35,6 +35,44 @@ from . import names from . import util +def _infer_type_from_right_hand_nameexpr( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: NameExpr, +) -> Optional[ProperType]: + + type_id = names._type_id_for_callee(infer_from_right_side) + + if type_id is None: + return None + elif type_id is names.COLUMN: + python_type_for_type = _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.RELATIONSHIP: + python_type_for_type = _infer_type_from_relationship( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.COLUMN_PROPERTY: + python_type_for_type = _infer_type_from_decl_column_property( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.SYNONYM_PROPERTY: + python_type_for_type = _infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif type_id is names.COMPOSITE_PROPERTY: + python_type_for_type = _infer_type_from_decl_composite_property( + api, stmt, node, left_hand_explicit_type + ) + else: + return None + + return python_type_for_type + + def _infer_type_from_relationship( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, @@ -255,7 +293,11 @@ def _infer_type_from_decl_column_property( # argument if type_id is names.COLUMN: return _infer_type_from_decl_column( - api, stmt, node, left_hand_explicit_type, first_prop_arg + api, + stmt, + node, + left_hand_explicit_type, + right_hand_expression=first_prop_arg, ) return _infer_type_from_left_hand_type_only( @@ -268,7 +310,7 @@ def _infer_type_from_decl_column( stmt: AssignmentStmt, node: Var, left_hand_explicit_type: Optional[ProperType], - right_hand_expression: CallExpr, + right_hand_expression: Optional[CallExpr] = None, ) -> Optional[ProperType]: """Infer the type of mapping from a Column. @@ -305,6 +347,12 @@ def _infer_type_from_decl_column( callee = None + if right_hand_expression is None: + if not isinstance(stmt.rvalue, CallExpr): + return None + + right_hand_expression = stmt.rvalue + for column_arg in right_hand_expression.args[0:2]: if isinstance(column_arg, CallExpr): if isinstance(column_arg.callee, RefExpr): diff --git a/setup.cfg b/setup.cfg index 001a38e30f..744b858e82 100644 --- a/setup.cfg +++ b/setup.cfg @@ -120,6 +120,7 @@ per-file-ignores = # min mypy version 0.800 strict = True incremental = True +plugins = sqlalchemy.ext.mypy.plugin [mypy-sqlalchemy.*] ignore_errors = True diff --git a/test/ext/mypy/files/relationship_6255_one.py b/test/ext/mypy/files/relationship_6255_one.py new file mode 100644 index 0000000000..e5a180b479 --- /dev/null +++ b/test/ext/mypy/files/relationship_6255_one.py @@ -0,0 +1,51 @@ +from typing import List +from typing import Optional + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + name = Column(String) + + addresses: Mapped[List["Address"]] = relationship( + "Address", back_populates="user" + ) + + @property + def some_property(self) -> List[Optional[int]]: + return [i.id for i in self.addresses] + + +class Address(Base): + __tablename__ = "address" + + id = Column(Integer, primary_key=True) + user_id: int = Column(ForeignKey("user.id")) + + user: "User" = relationship("User", back_populates="addresses") + + @property + def some_other_property(self) -> Optional[str]: + return self.user.name + + +# it's in the constructor, correct type +u1 = User(addresses=[Address()]) + +# knows it's an iterable +[x for x in u1.addresses] + +# knows it's Mapped +stmt = select(User).where(User.addresses.any(id=5)) diff --git a/test/ext/mypy/files/relationship_6255_three.py b/test/ext/mypy/files/relationship_6255_three.py new file mode 100644 index 0000000000..121d8de40a --- /dev/null +++ b/test/ext/mypy/files/relationship_6255_three.py @@ -0,0 +1,48 @@ +from typing import List +from typing import Optional + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + name = Column(String) + + addresses: List["Address"] = relationship("Address", back_populates="user") + + @property + def some_property(self) -> List[Optional[int]]: + return [i.id for i in self.addresses] + + +class Address(Base): + __tablename__ = "address" + + id = Column(Integer, primary_key=True) + user_id: int = Column(ForeignKey("user.id")) + + user: "User" = relationship("User", back_populates="addresses") + + @property + def some_other_property(self) -> Optional[str]: + return self.user.name + + +# it's in the constructor, correct type +u1 = User(addresses=[Address()]) + +# knows it's an iterable +[x for x in u1.addresses] + +# knows it's Mapped +stmt = select(User).where(User.addresses.any(id=5)) diff --git a/test/ext/mypy/files/relationship_6255_two.py b/test/ext/mypy/files/relationship_6255_two.py new file mode 100644 index 0000000000..121d8de40a --- /dev/null +++ b/test/ext/mypy/files/relationship_6255_two.py @@ -0,0 +1,48 @@ +from typing import List +from typing import Optional + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + name = Column(String) + + addresses: List["Address"] = relationship("Address", back_populates="user") + + @property + def some_property(self) -> List[Optional[int]]: + return [i.id for i in self.addresses] + + +class Address(Base): + __tablename__ = "address" + + id = Column(Integer, primary_key=True) + user_id: int = Column(ForeignKey("user.id")) + + user: "User" = relationship("User", back_populates="addresses") + + @property + def some_other_property(self) -> Optional[str]: + return self.user.name + + +# it's in the constructor, correct type +u1 = User(addresses=[Address()]) + +# knows it's an iterable +[x for x in u1.addresses] + +# knows it's Mapped +stmt = select(User).where(User.addresses.any(id=5)) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index 4ab16540d3..c853f7be5d 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -171,9 +171,13 @@ class MypyPluginTest(fixtures.TestBase): errors.append(e) for num, is_mypy, msg in expected_errors: + msg = msg.replace("'", '"') prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" for idx, errmsg in enumerate(errors): - if f"{filename}:{num + 1}: error: {prefix}{msg}" in errmsg: + if ( + f"{filename}:{num + 1}: error: {prefix}{msg}" + in errmsg.replace("'", '"') + ): break else: continue