]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Accommodate for callable fns for collection_class
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Apr 2021 15:12:49 +0000 (11:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Apr 2021 15:13:49 +0000 (11:13 -0400)
Fixed issue where the Mypy plugin would fail to interpret the
"collection_class" of a relationship if it were a callable and not a class.
Also improved type matching and error reporting for collection-oriented
relationships.

Fixes: #6205
Change-Id: If3cb0defd4d7336e06a3bb3a3e8d59ea34b4c98d

doc/build/changelog/unreleased_14/6205.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mypy/infer.py
test/ext/mypy/files/orderinglist1.py [new file with mode: 0644]
test/ext/mypy/files/orderinglist2.py [new file with mode: 0644]

diff --git a/doc/build/changelog/unreleased_14/6205.rst b/doc/build/changelog/unreleased_14/6205.rst
new file mode 100644 (file)
index 0000000..8f79eb4
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, mypy
+    :tickets: 6205
+
+    Fixed issue where the Mypy plugin would fail to interpret the
+    "collection_class" of a relationship if it were a callable and not a class.
+    Also improved type matching and error reporting for collection-oriented
+    relationships.
+
index 1d77e67d2e3284d60e0901ba56fddfa18949993c..49dd9fb7435fba5db26ad03785106583a4b9312c 100644 (file)
@@ -88,6 +88,7 @@ def _infer_type_from_relationship(
     collection_cls_arg = util._get_callexpr_kwarg(
         stmt.rvalue, "collection_class"
     )
+    type_is_a_collection = False
 
     # this can be used to determine Optional for a many-to-one
     # in the same way nullable=False could be used, if we start supporting
@@ -99,6 +100,7 @@ def _infer_type_from_relationship(
         and uselist_arg.fullname == "builtins.True"
         and collection_cls_arg is None
     ):
+        type_is_a_collection = True
         if python_type_for_type is not None:
             python_type_for_type = Instance(
                 api.lookup_fully_qualified("builtins.list").node,
@@ -107,8 +109,16 @@ def _infer_type_from_relationship(
     elif (
         uselist_arg is None or uselist_arg.fullname == "builtins.True"
     ) and collection_cls_arg is not None:
-        if isinstance(collection_cls_arg.node, TypeInfo):
+        type_is_a_collection = True
+        if isinstance(collection_cls_arg, CallExpr):
+            collection_cls_arg = collection_cls_arg.callee
+
+        if isinstance(collection_cls_arg, NameExpr) and isinstance(
+            collection_cls_arg.node, TypeInfo
+        ):
             if python_type_for_type is not None:
+                # this can still be overridden by the left hand side
+                # within _infer_Type_from_left_and_inferred_right
                 python_type_for_type = Instance(
                     collection_cls_arg.node, [python_type_for_type]
                 )
@@ -150,7 +160,11 @@ def _infer_type_from_relationship(
         )
     elif left_hand_explicit_type is not None:
         return _infer_type_from_left_and_inferred_right(
-            api, node, left_hand_explicit_type, python_type_for_type
+            api,
+            node,
+            left_hand_explicit_type,
+            python_type_for_type,
+            type_is_a_collection=type_is_a_collection,
         )
     else:
         return python_type_for_type
@@ -317,6 +331,7 @@ def _infer_type_from_left_and_inferred_right(
     node: Var,
     left_hand_explicit_type: Optional[types.Type],
     python_type_for_type: Union[Instance, UnionType],
+    type_is_a_collection: bool = False,
 ) -> Optional[Union[Instance, UnionType]]:
     """Validate type when a left hand annotation is present and we also
     could infer the right hand side::
@@ -324,10 +339,18 @@ def _infer_type_from_left_and_inferred_right(
         attrname: SomeType = Column(SomeDBType)
 
     """
+
+    orig_left_hand_type = left_hand_explicit_type
+    orig_python_type_for_type = python_type_for_type
+
+    if type_is_a_collection and left_hand_explicit_type.args:
+        left_hand_explicit_type = left_hand_explicit_type.args[0]
+        python_type_for_type = python_type_for_type.args[0]
+
     if not is_subtype(left_hand_explicit_type, python_type_for_type):
         descriptor = api.lookup("__sa_Mapped", node)
 
-        effective_type = Instance(descriptor.node, [python_type_for_type])
+        effective_type = Instance(descriptor.node, [orig_python_type_for_type])
 
         msg = (
             "Left hand assignment '{}: {}' not compatible "
@@ -337,13 +360,13 @@ def _infer_type_from_left_and_inferred_right(
             api,
             msg.format(
                 node.name,
-                format_type(left_hand_explicit_type),
+                format_type(orig_left_hand_type),
                 format_type(effective_type),
             ),
             node,
         )
 
-    return left_hand_explicit_type
+    return orig_left_hand_type
 
 
 def _infer_type_from_left_hand_type_only(
diff --git a/test/ext/mypy/files/orderinglist1.py b/test/ext/mypy/files/orderinglist1.py
new file mode 100644 (file)
index 0000000..cb06ba4
--- /dev/null
@@ -0,0 +1,25 @@
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy.ext.orderinglist import OrderingList
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+
+mapper_registry: registry = registry()
+
+
+@mapper_registry.mapped
+class A:
+    __tablename__ = "a"
+    id = Column(Integer, primary_key=True)
+
+    # EXPECTED: Can't infer type from ORM mapped expression assigned to attribute 'parents'; please specify a Python type or Mapped[<python type>] on the left hand side.  # noqa
+    parents = relationship("A", collection_class=OrderingList("ordering"))
+    parent_id = Column(Integer, ForeignKey("a.id"))
+    ordering = Column(Integer)
+
+
+a1 = A(id=5, ordering=10)
+
+# EXPECTED_MYPY: Argument "parents" to "A" has incompatible type "List[A]"; expected "Mapped[Any]"  # noqa
+a2 = A(parents=[a1])
diff --git a/test/ext/mypy/files/orderinglist2.py b/test/ext/mypy/files/orderinglist2.py
new file mode 100644 (file)
index 0000000..2a1623e
--- /dev/null
@@ -0,0 +1,53 @@
+from typing import List
+
+from sqlalchemy import Column
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy.ext.orderinglist import OrderingList
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import relationship
+
+mapper_registry: registry = registry()
+
+
+@mapper_registry.mapped
+class B:
+    __tablename__ = "b"
+    id = Column(Integer, primary_key=True)
+    parent_id = Column(Integer, ForeignKey("a.id"))
+    ordering = Column(Integer)
+
+
+@mapper_registry.mapped
+class C:
+    __tablename__ = "c"
+    id = Column(Integer, primary_key=True)
+    parent_id = Column(Integer, ForeignKey("a.id"))
+    ordering = Column(Integer)
+
+
+@mapper_registry.mapped
+class A:
+    __tablename__ = "a"
+    id = Column(Integer, primary_key=True)
+
+    bs = relationship(B, collection_class=OrderingList("ordering"))
+
+    bs_w_list: List[B] = relationship(
+        B, collection_class=OrderingList("ordering")
+    )
+
+    # EXPECTED: Left hand assignment 'cs: "List[B]"' not compatible with ORM mapped expression of type "Mapped[List[C]]"  # noqa
+    cs: List[B] = relationship(C, uselist=True)
+
+    # EXPECTED: Left hand assignment 'cs_2: "B"' not compatible with ORM mapped expression of type "Mapped[List[C]]"  # noqa
+    cs_2: B = relationship(C, uselist=True)
+
+
+b1 = B(ordering=10)
+
+# in this case, the plugin infers OrderingList as the type.  not great
+a1 = A(bs=OrderingList(b1))
+
+# so we want to support being able to override it at least
+a2 = A(bs_w_list=[b1])