]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Re-infer statements that got more specific on subsequent pass
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Apr 2021 22:03:12 +0000 (18:03 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Apr 2021 17:05:49 +0000 (13:05 -0400)
Fixed issue where mypy plugin would not correctly interpret an explicit
:class:`_orm.Mapped` annotation in conjunction with a
:func:`_orm.relationship` that refers to a class by string name; the
correct annotation would be downgraded to a less specific one leading to
typing errors.

The thing figured out here is that after we've already scanned
a class in the semanal stage and created DeclClassApplied,
when we are called again with that same DeclClassApplied, for this
specific kind of case we actually now have *better* types than
we did before, where the left side that looked like
List?[Address?] now seems to say
builtins.list[official.module.Address] - so let's take the
right side expression again, this time embedded in our
Mapped._empty_constructor() expression, and run the infer
all over again just like mypy would.   Just not setting the
"wrong" type here fixed the test cases but by re-applying the
whole infer we get the correct Mapped[] on the left side too.

Fixes: #6255
Change-Id: Iafe7254374f685a8458c7a1db82aafc2ed6d0232

doc/build/changelog/unreleased_14/6255.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/apply.py
lib/sqlalchemy/ext/mypy/decl_class.py
lib/sqlalchemy/ext/mypy/infer.py
setup.cfg
test/ext/mypy/files/relationship_6255_one.py [new file with mode: 0644]
test/ext/mypy/files/relationship_6255_three.py [new file with mode: 0644]
test/ext/mypy/files/relationship_6255_two.py [new file with mode: 0644]
test/ext/mypy/test_mypy_plugin_py3k.py

diff --git a/doc/build/changelog/unreleased_14/6255.rst b/doc/build/changelog/unreleased_14/6255.rst
new file mode 100644 (file)
index 0000000..0211fb3
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, mypy
+    :tickets: 6255
+
+    Fixed issue where mypy plugin would not correctly interpret an explicit
+    :class:`_orm.Mapped` annotation in conjunction with a
+    :func:`_orm.relationship` that refers to a class by string name; the
+    correct annotation would be downgraded to a less specific one leading to
+    typing errors.
index 3662604373da6443da7a88b57e96516478d6cc0d..5dc9ec0b1741eec15424f1e073e9d8449a3a1e14 100644 (file)
@@ -15,6 +15,7 @@ from mypy.nodes import AssignmentStmt
 from mypy.nodes import CallExpr
 from mypy.nodes import ClassDef
 from mypy.nodes import MDEF
+from mypy.nodes import MemberExpr
 from mypy.nodes import NameExpr
 from mypy.nodes import StrExpr
 from mypy.nodes import SymbolTableNode
@@ -32,6 +33,7 @@ from mypy.types import TypeOfAny
 from mypy.types import UnboundType
 from mypy.types import UnionType
 
+from . import infer
 from . import util
 
 
@@ -92,6 +94,7 @@ def _re_apply_declarative_assignments(
     mapped_attr_lookup = {
         name: typ for name, typ in cls_metadata.mapped_attr_names
     }
+    update_cls_metadata = False
 
     for stmt in cls.defs.body:
         # for a re-apply, all of our statements are AssignmentStmt;
@@ -104,10 +107,51 @@ def _re_apply_declarative_assignments(
             and stmt.lvalues[0].name in mapped_attr_lookup
             and isinstance(stmt.lvalues[0].node, Var)
         ):
-            typ = mapped_attr_lookup[stmt.lvalues[0].name]
+
             left_node = stmt.lvalues[0].node
+            python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name]
+            # if we have scanned an UnboundType and now there's a more
+            # specific type than UnboundType, call the re-scan so we
+            # can get that set up correctly
+            if (
+                isinstance(python_type_for_type, UnboundType)
+                and not isinstance(left_node.type, UnboundType)
+                and (
+                    isinstance(stmt.rvalue.callee, MemberExpr)
+                    and stmt.rvalue.callee.expr.node.fullname
+                    == "sqlalchemy.orm.attributes.Mapped"
+                    and stmt.rvalue.callee.name == "_empty_constructor"
+                    and isinstance(stmt.rvalue.args[0], CallExpr)
+                )
+            ):
+
+                python_type_for_type = (
+                    infer._infer_type_from_right_hand_nameexpr(
+                        api,
+                        stmt,
+                        left_node,
+                        left_node.type,
+                        stmt.rvalue.args[0].callee,
+                    )
+                )
+
+                if python_type_for_type is None or isinstance(
+                    python_type_for_type, UnboundType
+                ):
+                    continue
+
+                # update the DeclClassApplied with the better information
+                mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type
+                update_cls_metadata = True
+
+            left_node.type = api.named_type(
+                "__sa_Mapped", [python_type_for_type]
+            )
 
-            left_node.type = api.named_type("__sa_Mapped", [typ])
+    if update_cls_metadata:
+        cls_metadata.mapped_attr_names[:] = [
+            (k, v) for k, v in mapped_attr_lookup.items()
+        ]
 
 
 def _apply_type_to_mapped_statement(
index 8fac36342b4244d993e2856f71408729bba41672..2870eeb6fbdd2fe14ae7406502b1f905eec53edf 100644 (file)
@@ -63,10 +63,11 @@ def _scan_declarative_assignments_and_apply_types(
         if not is_mixin_scan:
             assert cls_metadata.is_mapped
 
-            # mypy can call us more than once.  it then will have reset the
+            # mypy can call us more than once.  it then *may* have reset the
             # left hand side of everything, but not the right that we removed,
             # removing our ability to re-scan.   but we have the types
-            # here, so lets re-apply them.
+            # here, so lets re-apply them, or if we have an UnboundType,
+            # we can re-scan
 
             apply._re_apply_declarative_assignments(cls, api, cls_metadata)
 
@@ -422,33 +423,11 @@ def _scan_declarative_assignment_stmt(
         stmt.rvalue.callee, RefExpr
     ):
 
-        type_id = names._type_id_for_callee(stmt.rvalue.callee)
+        python_type_for_type = infer._infer_type_from_right_hand_nameexpr(
+            api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
+        )
 
-        if type_id is None:
-            return
-        elif type_id is names.COLUMN:
-            python_type_for_type = infer._infer_type_from_decl_column(
-                api, stmt, node, left_hand_explicit_type, stmt.rvalue
-            )
-        elif type_id is names.RELATIONSHIP:
-            python_type_for_type = infer._infer_type_from_relationship(
-                api, stmt, node, left_hand_explicit_type
-            )
-        elif type_id is names.COLUMN_PROPERTY:
-            python_type_for_type = infer._infer_type_from_decl_column_property(
-                api, stmt, node, left_hand_explicit_type
-            )
-        elif type_id is names.SYNONYM_PROPERTY:
-            python_type_for_type = infer._infer_type_from_left_hand_type_only(
-                api, node, left_hand_explicit_type
-            )
-        elif type_id is names.COMPOSITE_PROPERTY:
-            python_type_for_type = (
-                infer._infer_type_from_decl_composite_property(
-                    api, stmt, node, left_hand_explicit_type
-                )
-            )
-        else:
+        if python_type_for_type is None:
             return
 
     else:
index d734d588e53fd7d68e165ce2ff67eec54e0a61cf..2fea6d340ace2ed5a1f9e50d0ea49a37effe706b 100644 (file)
@@ -35,6 +35,44 @@ from . import names
 from . import util
 
 
+def _infer_type_from_right_hand_nameexpr(
+    api: SemanticAnalyzerPluginInterface,
+    stmt: AssignmentStmt,
+    node: Var,
+    left_hand_explicit_type: Optional[ProperType],
+    infer_from_right_side: NameExpr,
+) -> Optional[ProperType]:
+
+    type_id = names._type_id_for_callee(infer_from_right_side)
+
+    if type_id is None:
+        return None
+    elif type_id is names.COLUMN:
+        python_type_for_type = _infer_type_from_decl_column(
+            api, stmt, node, left_hand_explicit_type
+        )
+    elif type_id is names.RELATIONSHIP:
+        python_type_for_type = _infer_type_from_relationship(
+            api, stmt, node, left_hand_explicit_type
+        )
+    elif type_id is names.COLUMN_PROPERTY:
+        python_type_for_type = _infer_type_from_decl_column_property(
+            api, stmt, node, left_hand_explicit_type
+        )
+    elif type_id is names.SYNONYM_PROPERTY:
+        python_type_for_type = _infer_type_from_left_hand_type_only(
+            api, node, left_hand_explicit_type
+        )
+    elif type_id is names.COMPOSITE_PROPERTY:
+        python_type_for_type = _infer_type_from_decl_composite_property(
+            api, stmt, node, left_hand_explicit_type
+        )
+    else:
+        return None
+
+    return python_type_for_type
+
+
 def _infer_type_from_relationship(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
@@ -255,7 +293,11 @@ def _infer_type_from_decl_column_property(
         # argument
         if type_id is names.COLUMN:
             return _infer_type_from_decl_column(
-                api, stmt, node, left_hand_explicit_type, first_prop_arg
+                api,
+                stmt,
+                node,
+                left_hand_explicit_type,
+                right_hand_expression=first_prop_arg,
             )
 
     return _infer_type_from_left_hand_type_only(
@@ -268,7 +310,7 @@ def _infer_type_from_decl_column(
     stmt: AssignmentStmt,
     node: Var,
     left_hand_explicit_type: Optional[ProperType],
-    right_hand_expression: CallExpr,
+    right_hand_expression: Optional[CallExpr] = None,
 ) -> Optional[ProperType]:
     """Infer the type of mapping from a Column.
 
@@ -305,6 +347,12 @@ def _infer_type_from_decl_column(
 
     callee = None
 
+    if right_hand_expression is None:
+        if not isinstance(stmt.rvalue, CallExpr):
+            return None
+
+        right_hand_expression = stmt.rvalue
+
     for column_arg in right_hand_expression.args[0:2]:
         if isinstance(column_arg, CallExpr):
             if isinstance(column_arg.callee, RefExpr):
index 001a38e30ffecc2dc7dce41848be529c0a7b4d11..744b858e82b0c81f7bf288823fe17f6fe488738c 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -120,6 +120,7 @@ per-file-ignores =
 # min mypy version 0.800
 strict = True
 incremental = True
+plugins = sqlalchemy.ext.mypy.plugin
 
 [mypy-sqlalchemy.*]
 ignore_errors = True
diff --git a/test/ext/mypy/files/relationship_6255_one.py b/test/ext/mypy/files/relationship_6255_one.py
new file mode 100644 (file)
index 0000000..e5a180b
--- /dev/null
@@ -0,0 +1,51 @@
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String)
+
+    addresses: Mapped[List["Address"]] = relationship(
+        "Address", back_populates="user"
+    )
+
+    @property
+    def some_property(self) -> List[Optional[int]]:
+        return [i.id for i in self.addresses]
+
+
+class Address(Base):
+    __tablename__ = "address"
+
+    id = Column(Integer, primary_key=True)
+    user_id: int = Column(ForeignKey("user.id"))
+
+    user: "User" = relationship("User", back_populates="addresses")
+
+    @property
+    def some_other_property(self) -> Optional[str]:
+        return self.user.name
+
+
+# it's in the constructor, correct type
+u1 = User(addresses=[Address()])
+
+# knows it's an iterable
+[x for x in u1.addresses]
+
+# knows it's Mapped
+stmt = select(User).where(User.addresses.any(id=5))
diff --git a/test/ext/mypy/files/relationship_6255_three.py b/test/ext/mypy/files/relationship_6255_three.py
new file mode 100644 (file)
index 0000000..121d8de
--- /dev/null
@@ -0,0 +1,48 @@
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String)
+
+    addresses: List["Address"] = relationship("Address", back_populates="user")
+
+    @property
+    def some_property(self) -> List[Optional[int]]:
+        return [i.id for i in self.addresses]
+
+
+class Address(Base):
+    __tablename__ = "address"
+
+    id = Column(Integer, primary_key=True)
+    user_id: int = Column(ForeignKey("user.id"))
+
+    user: "User" = relationship("User", back_populates="addresses")
+
+    @property
+    def some_other_property(self) -> Optional[str]:
+        return self.user.name
+
+
+# it's in the constructor, correct type
+u1 = User(addresses=[Address()])
+
+# knows it's an iterable
+[x for x in u1.addresses]
+
+# knows it's Mapped
+stmt = select(User).where(User.addresses.any(id=5))
diff --git a/test/ext/mypy/files/relationship_6255_two.py b/test/ext/mypy/files/relationship_6255_two.py
new file mode 100644 (file)
index 0000000..121d8de
--- /dev/null
@@ -0,0 +1,48 @@
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import relationship
+
+Base = declarative_base()
+
+
+class User(Base):
+    __tablename__ = "user"
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String)
+
+    addresses: List["Address"] = relationship("Address", back_populates="user")
+
+    @property
+    def some_property(self) -> List[Optional[int]]:
+        return [i.id for i in self.addresses]
+
+
+class Address(Base):
+    __tablename__ = "address"
+
+    id = Column(Integer, primary_key=True)
+    user_id: int = Column(ForeignKey("user.id"))
+
+    user: "User" = relationship("User", back_populates="addresses")
+
+    @property
+    def some_other_property(self) -> Optional[str]:
+        return self.user.name
+
+
+# it's in the constructor, correct type
+u1 = User(addresses=[Address()])
+
+# knows it's an iterable
+[x for x in u1.addresses]
+
+# knows it's Mapped
+stmt = select(User).where(User.addresses.any(id=5))
index 4ab16540d39274a6209937fc82c9b908f8c35174..c853f7be5d76c4ed06c0eb0e4fb7b57aee822100 100644 (file)
@@ -171,9 +171,13 @@ class MypyPluginTest(fixtures.TestBase):
                     errors.append(e)
 
             for num, is_mypy, msg in expected_errors:
+                msg = msg.replace("'", '"')
                 prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
                 for idx, errmsg in enumerate(errors):
-                    if f"{filename}:{num + 1}: error: {prefix}{msg}" in errmsg:
+                    if (
+                        f"{filename}:{num + 1}: error: {prefix}{msg}"
+                        in errmsg.replace("'", '"')
+                    ):
                         break
                 else:
                     continue