]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update mypy plugin to conform to strict mode
authorBryan Forbes <bryan@reigndropsfall.net>
Mon, 12 Apr 2021 21:24:37 +0000 (16:24 -0500)
committerBryan Forbes <bryan@reigndropsfall.net>
Mon, 12 Apr 2021 21:24:37 +0000 (16:24 -0500)
Change-Id: I09a3df5af2f2d4ee34d8d72c3dedc4f236df8eb1

lib/sqlalchemy/ext/mypy/apply.py
lib/sqlalchemy/ext/mypy/decl_class.py
lib/sqlalchemy/ext/mypy/infer.py
lib/sqlalchemy/ext/mypy/names.py
lib/sqlalchemy/ext/mypy/plugin.py
lib/sqlalchemy/ext/mypy/util.py
setup.cfg

index 0f4bb1fd9b30024e0a8dc83b35b31be328b678a0..3662604373da6443da7a88b57e96516478d6cc0d 100644 (file)
@@ -24,9 +24,12 @@ from mypy.nodes import Var
 from mypy.plugin import SemanticAnalyzerPluginInterface
 from mypy.plugins.common import add_method_to_class
 from mypy.types import AnyType
+from mypy.types import get_proper_type
 from mypy.types import Instance
 from mypy.types import NoneTyp
+from mypy.types import ProperType
 from mypy.types import TypeOfAny
+from mypy.types import UnboundType
 from mypy.types import UnionType
 
 from . import util
@@ -37,7 +40,7 @@ def _apply_mypy_mapped_attr(
     api: SemanticAnalyzerPluginInterface,
     item: Union[NameExpr, StrExpr],
     cls_metadata: util.DeclClassApplied,
-):
+) -> None:
     if isinstance(item, NameExpr):
         name = item.name
     elif isinstance(item, StrExpr):
@@ -46,7 +49,11 @@ def _apply_mypy_mapped_attr(
         return
 
     for stmt in cls.defs.body:
-        if isinstance(stmt, AssignmentStmt) and stmt.lvalues[0].name == name:
+        if (
+            isinstance(stmt, AssignmentStmt)
+            and isinstance(stmt.lvalues[0], NameExpr)
+            and stmt.lvalues[0].name == name
+        ):
             break
     else:
         util.fail(api, "Can't find mapped attribute {}".format(name), cls)
@@ -61,7 +68,10 @@ def _apply_mypy_mapped_attr(
         )
         return
 
-    left_hand_explicit_type = stmt.type
+    left_hand_explicit_type = get_proper_type(stmt.type)
+    assert isinstance(
+        left_hand_explicit_type, (Instance, UnionType, UnboundType)
+    )
 
     cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
 
@@ -74,7 +84,7 @@ def _re_apply_declarative_assignments(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
     cls_metadata: util.DeclClassApplied,
-):
+) -> None:
     """For multiple class passes, re-apply our left-hand side types as mypy
     seems to reset them in place.
 
@@ -90,7 +100,9 @@ def _re_apply_declarative_assignments(
         # will change).
         if (
             isinstance(stmt, AssignmentStmt)
+            and isinstance(stmt.lvalues[0], NameExpr)
             and stmt.lvalues[0].name in mapped_attr_lookup
+            and isinstance(stmt.lvalues[0].node, Var)
         ):
             typ = mapped_attr_lookup[stmt.lvalues[0].name]
             left_node = stmt.lvalues[0].node
@@ -102,8 +114,8 @@ def _apply_type_to_mapped_statement(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     lvalue: NameExpr,
-    left_hand_explicit_type: Optional[Union[Instance, UnionType]],
-    python_type_for_type: Union[Instance, UnionType],
+    left_hand_explicit_type: Optional[ProperType],
+    python_type_for_type: Optional[ProperType],
 ) -> None:
     """Apply the Mapped[<type>] annotation and right hand object to a
     declarative assignment statement.
@@ -124,6 +136,7 @@ def _apply_type_to_mapped_statement(
 
     """
     left_node = lvalue.node
+    assert isinstance(left_node, Var)
 
     if left_hand_explicit_type is not None:
         left_node.type = api.named_type(
@@ -131,7 +144,10 @@ def _apply_type_to_mapped_statement(
         )
     else:
         lvalue.is_inferred_def = False
-        left_node.type = api.named_type("__sa_Mapped", [python_type_for_type])
+        left_node.type = api.named_type(
+            "__sa_Mapped",
+            [] if python_type_for_type is None else [python_type_for_type],
+        )
 
     # so to have it skip the right side totally, we can do this:
     # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
@@ -146,7 +162,7 @@ def _apply_type_to_mapped_statement(
     # the original right-hand side is maintained so it gets type checked
     # internally
     column_descriptor = nodes.NameExpr("__sa_Mapped")
-    column_descriptor.fullname = "sqlalchemy.orm.Mapped"
+    column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
     mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
     orig_call_expr = stmt.rvalue
     stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
@@ -199,11 +215,11 @@ def _apply_placeholder_attr_to_class(
     cls: ClassDef,
     qualified_name: str,
     attrname: str,
-):
+) -> None:
     sym = api.lookup_fully_qualified_or_none(qualified_name)
     if sym:
         assert isinstance(sym.node, TypeInfo)
-        type_: Union[Instance, AnyType] = Instance(sym.node, [])
+        type_: ProperType = Instance(sym.node, [])
     else:
         type_ = AnyType(TypeOfAny.special_form)
     var = Var(attrname)
index 40f1f0c0fad427e22fb22588b7d5abe5ea3c6557..8fac36342b4244d993e2856f71408729bba41672 100644 (file)
@@ -6,7 +6,7 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 from typing import Optional
-from typing import Type
+from typing import Union
 
 from mypy import nodes
 from mypy.nodes import AssignmentStmt
@@ -14,18 +14,24 @@ from mypy.nodes import CallExpr
 from mypy.nodes import ClassDef
 from mypy.nodes import Decorator
 from mypy.nodes import ListExpr
+from mypy.nodes import MemberExpr
 from mypy.nodes import NameExpr
 from mypy.nodes import PlaceholderNode
 from mypy.nodes import RefExpr
 from mypy.nodes import StrExpr
+from mypy.nodes import SymbolNode
 from mypy.nodes import SymbolTableNode
 from mypy.nodes import TempNode
 from mypy.nodes import TypeInfo
 from mypy.nodes import Var
 from mypy.plugin import SemanticAnalyzerPluginInterface
 from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
 from mypy.types import Instance
 from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
 from mypy.types import TypeOfAny
 from mypy.types import UnboundType
 from mypy.types import UnionType
@@ -37,7 +43,9 @@ from . import util
 
 
 def _scan_declarative_assignments_and_apply_types(
-    cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan=False
+    cls: ClassDef,
+    api: SemanticAnalyzerPluginInterface,
+    is_mixin_scan: bool = False,
 ) -> Optional[util.DeclClassApplied]:
 
     info = util._info_for_cls(cls, api)
@@ -94,16 +102,17 @@ def _scan_symbol_table_entry(
     name: str,
     value: SymbolTableNode,
     cls_metadata: util.DeclClassApplied,
-):
+) -> None:
     """Extract mapping information from a SymbolTableNode that's in the
     type.names dictionary.
 
     """
-    if not isinstance(value.type, Instance):
+    value_type = get_proper_type(value.type)
+    if not isinstance(value_type, Instance):
         return
 
     left_hand_explicit_type = None
-    type_id = names._type_id_for_named_node(value.type.type)
+    type_id = names._type_id_for_named_node(value_type.type)
     # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
 
     err = False
@@ -118,22 +127,24 @@ def _scan_symbol_table_entry(
         names.SYNONYM_PROPERTY,
         names.COLUMN_PROPERTY,
     }:
-        if value.type.args:
-            left_hand_explicit_type = value.type.args[0]
+        if value_type.args:
+            left_hand_explicit_type = get_proper_type(value_type.args[0])
         else:
             err = True
     elif type_id is names.COLUMN:
-        if not value.type.args:
+        if not value_type.args:
             err = True
         else:
-            typeengine_arg = value.type.args[0]
+            typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
+                value_type.args[0]
+            )
             if isinstance(typeengine_arg, Instance):
                 typeengine_arg = typeengine_arg.type
 
             if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
                 sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
-                if sym is not None:
-                    if names._mro_has_id(sym.node.mro, names.TYPEENGINE):
+                if sym is not None and isinstance(sym.node, TypeInfo):
+                    if names._has_base_type_id(sym.node, names.TYPEENGINE):
 
                         left_hand_explicit_type = UnionType(
                             [
@@ -148,7 +159,7 @@ def _scan_symbol_table_entry(
                             api,
                             "Column type should be a TypeEngine "
                             "subclass not '{}'".format(sym.node.fullname),
-                            value.type,
+                            value_type,
                         )
 
     if err:
@@ -158,7 +169,7 @@ def _scan_symbol_table_entry(
             "one of: Mapped[<python type>], relationship[<target class>], "
             "Column[<TypeEngine>], MapperProperty[<python type>]"
         )
-        util.fail(api, msg.format(name, cls.name))
+        util.fail(api, msg.format(name, cls.name), cls)
 
         left_hand_explicit_type = AnyType(TypeOfAny.special_form)
 
@@ -171,7 +182,7 @@ def _scan_declarative_decorator_stmt(
     api: SemanticAnalyzerPluginInterface,
     stmt: Decorator,
     cls_metadata: util.DeclClassApplied,
-):
+) -> None:
     """Extract mapping information from a @declared_attr in a declarative
     class.
 
@@ -195,16 +206,19 @@ def _scan_declarative_decorator_stmt(
 
     """
     for dec in stmt.decorators:
-        if names._type_id_for_named_node(dec) is names.DECLARED_ATTR:
+        if (
+            isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
+            and names._type_id_for_named_node(dec) is names.DECLARED_ATTR
+        ):
             break
     else:
         return
 
     dec_index = cls.defs.body.index(stmt)
 
-    left_hand_explicit_type = None
+    left_hand_explicit_type: Optional[ProperType] = None
 
-    if stmt.func.type is not None:
+    if isinstance(stmt.func.type, CallableType):
         func_type = stmt.func.type.ret_type
         if isinstance(func_type, UnboundType):
             type_id = names._type_id_for_unbound_type(func_type, cls, api)
@@ -225,30 +239,28 @@ def _scan_declarative_decorator_stmt(
             }
             and func_type.args
         ):
-            left_hand_explicit_type = func_type.args[0]
+            left_hand_explicit_type = get_proper_type(func_type.args[0])
         elif type_id is names.COLUMN and func_type.args:
             typeengine_arg = func_type.args[0]
             if isinstance(typeengine_arg, UnboundType):
                 sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
-                if sym is not None and names._mro_has_id(
-                    sym.node.mro, names.TYPEENGINE
-                ):
-
-                    left_hand_explicit_type = UnionType(
-                        [
-                            infer._extract_python_type_from_typeengine(
-                                api, sym.node, []
-                            ),
-                            NoneType(),
-                        ]
-                    )
-                else:
-                    util.fail(
-                        api,
-                        "Column type should be a TypeEngine "
-                        "subclass not '{}'".format(sym.node.fullname),
-                        func_type,
-                    )
+                if sym is not None and isinstance(sym.node, TypeInfo):
+                    if names._has_base_type_id(sym.node, names.TYPEENGINE):
+                        left_hand_explicit_type = UnionType(
+                            [
+                                infer._extract_python_type_from_typeengine(
+                                    api, sym.node, []
+                                ),
+                                NoneType(),
+                            ]
+                        )
+                    else:
+                        util.fail(
+                            api,
+                            "Column type should be a TypeEngine "
+                            "subclass not '{}'".format(sym.node.fullname),
+                            func_type,
+                        )
 
     if left_hand_explicit_type is None:
         # no type on the decorated function.  our option here is to
@@ -274,8 +286,8 @@ def _scan_declarative_decorator_stmt(
     # of converting it to the regular Instance/TypeInfo/UnionType structures
     # we see everywhere else.
     if isinstance(left_hand_explicit_type, UnboundType):
-        left_hand_explicit_type = util._unbound_to_instance(
-            api, left_hand_explicit_type
+        left_hand_explicit_type = get_proper_type(
+            util._unbound_to_instance(api, left_hand_explicit_type)
         )
 
     left_node.node.type = api.named_type(
@@ -315,7 +327,7 @@ def _scan_declarative_assignment_stmt(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     cls_metadata: util.DeclClassApplied,
-):
+) -> None:
     """Extract mapping information from an assignment statement in a
     declarative class.
 
@@ -339,7 +351,7 @@ def _scan_declarative_assignment_stmt(
     assert isinstance(node, Var)
 
     if node.name == "__abstract__":
-        if stmt.rvalue.fullname == "builtins.True":
+        if api.parse_bool(stmt.rvalue) is True:
             cls_metadata.is_mapped = False
         return
     elif node.name == "__tablename__":
@@ -354,7 +366,8 @@ def _scan_declarative_assignment_stmt(
                 if isinstance(item, (NameExpr, StrExpr)):
                     apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
 
-    left_hand_mapped_type: Type = None
+    left_hand_mapped_type: Optional[Type] = None
+    left_hand_explicit_type: Optional[ProperType] = None
 
     if node.is_inferred or node.type is None:
         if isinstance(stmt.type, UnboundType):
@@ -370,32 +383,33 @@ def _scan_declarative_assignment_stmt(
                 mapped_sym = api.lookup_qualified("Mapped", cls)
                 if (
                     mapped_sym is not None
+                    and mapped_sym.node is not None
                     and names._type_id_for_named_node(mapped_sym.node)
                     is names.MAPPED
                 ):
-                    left_hand_explicit_type = stmt.type.args[0]
+                    left_hand_explicit_type = get_proper_type(
+                        stmt.type.args[0]
+                    )
                     left_hand_mapped_type = stmt.type
 
             # TODO: do we need to convert from unbound for this case?
             # left_hand_explicit_type = util._unbound_to_instance(
             #     api, left_hand_explicit_type
             # )
-
-        else:
-            left_hand_explicit_type = None
     else:
+        node_type = get_proper_type(node.type)
         if (
-            isinstance(node.type, Instance)
-            and names._type_id_for_named_node(node.type.type) is names.MAPPED
+            isinstance(node_type, Instance)
+            and names._type_id_for_named_node(node_type.type) is names.MAPPED
         ):
             # print(node.type)
             # sqlalchemy.orm.attributes.Mapped[<python type>]
-            left_hand_explicit_type = node.type.args[0]
-            left_hand_mapped_type = node.type
+            left_hand_explicit_type = get_proper_type(node_type.args[0])
+            left_hand_mapped_type = node_type
         else:
             # print(node.type)
             # <python type>
-            left_hand_explicit_type = node.type
+            left_hand_explicit_type = node_type
             left_hand_mapped_type = None
 
     if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
@@ -440,10 +454,10 @@ def _scan_declarative_assignment_stmt(
     else:
         return
 
-    cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
-
     assert python_type_for_type is not None
 
+    cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+
     apply._apply_type_to_mapped_statement(
         api,
         stmt,
@@ -485,6 +499,6 @@ def _scan_for_mapped_bases(
             )
         )
 
-        if base_decl_class_applied not in (None, False):
+        if base_decl_class_applied is not None:
             cls_metadata.mapped_mro.append(base)
         baseclasses.extend(base.type.bases)
index f1bda7865532eea21670e96020082504abd02473..7915c3ae2dcdabbd7a2486b438e0ea44d2e61cf0 100644 (file)
@@ -6,23 +6,26 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 from typing import Optional
-from typing import Union
+from typing import Sequence
 
-from mypy import nodes
-from mypy import types
 from mypy.maptype import map_instance_to_supertype
 from mypy.messages import format_type
 from mypy.nodes import AssignmentStmt
 from mypy.nodes import CallExpr
+from mypy.nodes import Expression
+from mypy.nodes import MemberExpr
 from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
 from mypy.nodes import StrExpr
 from mypy.nodes import TypeInfo
 from mypy.nodes import Var
 from mypy.plugin import SemanticAnalyzerPluginInterface
 from mypy.subtypes import is_subtype
 from mypy.types import AnyType
+from mypy.types import get_proper_type
 from mypy.types import Instance
 from mypy.types import NoneType
+from mypy.types import ProperType
 from mypy.types import TypeOfAny
 from mypy.types import UnionType
 
@@ -34,8 +37,8 @@ def _infer_type_from_relationship(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     node: Var,
-    left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+    left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
     """Infer the type of mapping from a relationship.
 
     E.g.::
@@ -62,7 +65,7 @@ def _infer_type_from_relationship(
 
     assert isinstance(stmt.rvalue, CallExpr)
     target_cls_arg = stmt.rvalue.args[0]
-    python_type_for_type = None
+    python_type_for_type: Optional[ProperType] = None
 
     if isinstance(target_cls_arg, NameExpr) and isinstance(
         target_cls_arg.node, TypeInfo
@@ -86,7 +89,7 @@ def _infer_type_from_relationship(
     # isinstance(target_cls_arg, StrExpr)
 
     uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
-    collection_cls_arg = util._get_callexpr_kwarg(
+    collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg(
         stmt.rvalue, "collection_class"
     )
     type_is_a_collection = False
@@ -98,7 +101,7 @@ def _infer_type_from_relationship(
 
     if (
         uselist_arg is not None
-        and uselist_arg.fullname == "builtins.True"
+        and api.parse_bool(uselist_arg) is True
         and collection_cls_arg is None
     ):
         type_is_a_collection = True
@@ -107,7 +110,7 @@ def _infer_type_from_relationship(
                 "__builtins__.list", [python_type_for_type]
             )
     elif (
-        uselist_arg is None or uselist_arg.fullname == "builtins.True"
+        uselist_arg is None or api.parse_bool(uselist_arg) is True
     ) and collection_cls_arg is not None:
         type_is_a_collection = True
         if isinstance(collection_cls_arg, CallExpr):
@@ -130,7 +133,7 @@ def _infer_type_from_relationship(
                 stmt.rvalue,
             )
             python_type_for_type = None
-    elif uselist_arg is not None and uselist_arg.fullname == "builtins.False":
+    elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
         if collection_cls_arg is not None:
             util.fail(
                 api,
@@ -159,13 +162,19 @@ def _infer_type_from_relationship(
             api, node, left_hand_explicit_type
         )
     elif left_hand_explicit_type is not None:
-        return _infer_type_from_left_and_inferred_right(
-            api,
-            node,
-            left_hand_explicit_type,
-            python_type_for_type,
-            type_is_a_collection=type_is_a_collection,
-        )
+        if type_is_a_collection:
+            assert isinstance(left_hand_explicit_type, Instance)
+            assert isinstance(python_type_for_type, Instance)
+            return _infer_collection_type_from_left_and_inferred_right(
+                api, node, left_hand_explicit_type, python_type_for_type
+            )
+        else:
+            return _infer_type_from_left_and_inferred_right(
+                api,
+                node,
+                left_hand_explicit_type,
+                python_type_for_type,
+            )
     else:
         return python_type_for_type
 
@@ -174,8 +183,8 @@ def _infer_type_from_decl_composite_property(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     node: Var,
-    left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+    left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
     """Infer the type of mapping from a CompositeProperty."""
 
     assert isinstance(stmt.rvalue, CallExpr)
@@ -206,8 +215,8 @@ def _infer_type_from_decl_column_property(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     node: Var,
-    left_hand_explicit_type: Optional[types.Type],
-) -> Union[Instance, UnionType, None]:
+    left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
     """Infer the type of mapping from a ColumnProperty.
 
     This includes mappings against ``column_property()`` as well as the
@@ -219,28 +228,26 @@ def _infer_type_from_decl_column_property(
 
     if isinstance(first_prop_arg, CallExpr):
         type_id = names._type_id_for_callee(first_prop_arg.callee)
-    else:
-        type_id = None
 
-    # look for column_property() / deferred() etc with Column as first
-    # argument
-    if type_id is names.COLUMN:
-        return _infer_type_from_decl_column(
-            api, stmt, node, left_hand_explicit_type, first_prop_arg
-        )
-    else:
-        return _infer_type_from_left_hand_type_only(
-            api, node, left_hand_explicit_type
-        )
+        # look for column_property() / deferred() etc with Column as first
+        # argument
+        if type_id is names.COLUMN:
+            return _infer_type_from_decl_column(
+                api, stmt, node, left_hand_explicit_type, first_prop_arg
+            )
+
+    return _infer_type_from_left_hand_type_only(
+        api, node, left_hand_explicit_type
+    )
 
 
 def _infer_type_from_decl_column(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     node: Var,
-    left_hand_explicit_type: Optional[types.Type],
+    left_hand_explicit_type: Optional[ProperType],
     right_hand_expression: CallExpr,
-) -> Union[Instance, UnionType, None]:
+) -> Optional[ProperType]:
     """Infer the type of mapping from a Column.
 
     E.g.::
@@ -277,12 +284,13 @@ def _infer_type_from_decl_column(
     callee = None
 
     for column_arg in right_hand_expression.args[0:2]:
-        if isinstance(column_arg, nodes.CallExpr):
-            # x = Column(String(50))
-            callee = column_arg.callee
-            type_args = column_arg.args
-            break
-        elif isinstance(column_arg, (nodes.NameExpr, nodes.MemberExpr)):
+        if isinstance(column_arg, CallExpr):
+            if isinstance(column_arg.callee, RefExpr):
+                # x = Column(String(50))
+                callee = column_arg.callee
+                type_args: Sequence[Expression] = column_arg.args
+                break
+        elif isinstance(column_arg, (NameExpr, MemberExpr)):
             if isinstance(column_arg.node, TypeInfo):
                 # x = Column(String)
                 callee = column_arg
@@ -314,10 +322,7 @@ def _infer_type_from_decl_column(
             )
 
         else:
-            python_type_for_type = UnionType(
-                [python_type_for_type, NoneType()]
-            )
-        return python_type_for_type
+            return UnionType([python_type_for_type, NoneType()])
     else:
         # it's not TypeEngine, it's typically implicitly typed
         # like ForeignKey.  we can't infer from the right side.
@@ -329,10 +334,11 @@ def _infer_type_from_decl_column(
 def _infer_type_from_left_and_inferred_right(
     api: SemanticAnalyzerPluginInterface,
     node: Var,
-    left_hand_explicit_type: Optional[types.Type],
-    python_type_for_type: Union[Instance, UnionType],
-    type_is_a_collection: bool = False,
-) -> Optional[Union[Instance, UnionType]]:
+    left_hand_explicit_type: ProperType,
+    python_type_for_type: ProperType,
+    orig_left_hand_type: Optional[ProperType] = None,
+    orig_python_type_for_type: Optional[ProperType] = None,
+) -> Optional[ProperType]:
     """Validate type when a left hand annotation is present and we also
     could infer the right hand side::
 
@@ -340,12 +346,10 @@ def _infer_type_from_left_and_inferred_right(
 
     """
 
-    orig_left_hand_type = left_hand_explicit_type
-    orig_python_type_for_type = python_type_for_type
-
-    if type_is_a_collection and left_hand_explicit_type.args:
-        left_hand_explicit_type = left_hand_explicit_type.args[0]
-        python_type_for_type = python_type_for_type.args[0]
+    if orig_left_hand_type is None:
+        orig_left_hand_type = left_hand_explicit_type
+    if orig_python_type_for_type is None:
+        orig_python_type_for_type = python_type_for_type
 
     if not is_subtype(left_hand_explicit_type, python_type_for_type):
         effective_type = api.named_type(
@@ -369,11 +373,40 @@ def _infer_type_from_left_and_inferred_right(
     return orig_left_hand_type
 
 
+def _infer_collection_type_from_left_and_inferred_right(
+    api: SemanticAnalyzerPluginInterface,
+    node: Var,
+    left_hand_explicit_type: Instance,
+    python_type_for_type: Instance,
+) -> Optional[ProperType]:
+    orig_left_hand_type = left_hand_explicit_type
+    orig_python_type_for_type = python_type_for_type
+
+    if left_hand_explicit_type.args:
+        left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
+        python_type_arg = get_proper_type(python_type_for_type.args[0])
+    else:
+        left_hand_arg = left_hand_explicit_type
+        python_type_arg = python_type_for_type
+
+    assert isinstance(left_hand_arg, (Instance, UnionType))
+    assert isinstance(python_type_arg, (Instance, UnionType))
+
+    return _infer_type_from_left_and_inferred_right(
+        api,
+        node,
+        left_hand_arg,
+        python_type_arg,
+        orig_left_hand_type=orig_left_hand_type,
+        orig_python_type_for_type=orig_python_type_for_type,
+    )
+
+
 def _infer_type_from_left_hand_type_only(
     api: SemanticAnalyzerPluginInterface,
     node: Var,
-    left_hand_explicit_type: Optional[types.Type],
-) -> Optional[Union[Instance, UnionType]]:
+    left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
     """Determine the type based on explicit annotation only.
 
     if no annotation were present, note that we need one there to know
@@ -397,8 +430,10 @@ def _infer_type_from_left_hand_type_only(
 
 
 def _extract_python_type_from_typeengine(
-    api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args
-) -> Instance:
+    api: SemanticAnalyzerPluginInterface,
+    node: TypeInfo,
+    type_args: Sequence[Expression],
+) -> ProperType:
     if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
         first_arg = type_args[0]
         if isinstance(first_arg, NameExpr) and isinstance(
@@ -426,4 +461,4 @@ def _extract_python_type_from_typeengine(
         Instance(node, []),
         type_engine_sym.node,
     )
-    return type_engine.args[-1]
+    return get_proper_type(type_engine.args[-1])
index 174a8f422e8c866c0619f6c3cf67cd5ea38306d4..6ee600cd792fe0027959c78633c5acb9b6f88e6d 100644 (file)
@@ -5,40 +5,48 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+from typing import Dict
 from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Union
 
 from mypy.nodes import ClassDef
 from mypy.nodes import Expression
 from mypy.nodes import FuncDef
-from mypy.nodes import RefExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
 from mypy.nodes import SymbolNode
 from mypy.nodes import TypeAlias
 from mypy.nodes import TypeInfo
-from mypy.nodes import Union
 from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
 from mypy.types import UnboundType
 
 from ... import util
 
-COLUMN = util.symbol("COLUMN")
-RELATIONSHIP = util.symbol("RELATIONSHIP")
-REGISTRY = util.symbol("REGISTRY")
-COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
-TYPEENGINE = util.symbol("TYPEENGNE")
-MAPPED = util.symbol("MAPPED")
-DECLARATIVE_BASE = util.symbol("DECLARATIVE_BASE")
-DECLARATIVE_META = util.symbol("DECLARATIVE_META")
-MAPPED_DECORATOR = util.symbol("MAPPED_DECORATOR")
-COLUMN_PROPERTY = util.symbol("COLUMN_PROPERTY")
-SYNONYM_PROPERTY = util.symbol("SYNONYM_PROPERTY")
-COMPOSITE_PROPERTY = util.symbol("COMPOSITE_PROPERTY")
-DECLARED_ATTR = util.symbol("DECLARED_ATTR")
-MAPPER_PROPERTY = util.symbol("MAPPER_PROPERTY")
-AS_DECLARATIVE = util.symbol("AS_DECLARATIVE")
-AS_DECLARATIVE_BASE = util.symbol("AS_DECLARATIVE_BASE")
-DECLARATIVE_MIXIN = util.symbol("DECLARATIVE_MIXIN")
-
-_lookup = {
+COLUMN: int = util.symbol("COLUMN")  # type: ignore
+RELATIONSHIP: int = util.symbol("RELATIONSHIP")  # type: ignore
+REGISTRY: int = util.symbol("REGISTRY")  # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")  # type: ignore
+TYPEENGINE: int = util.symbol("TYPEENGNE")  # type: ignore
+MAPPED: int = util.symbol("MAPPED")  # type: ignore
+DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE")  # type: ignore
+DECLARATIVE_META: int = util.symbol("DECLARATIVE_META")  # type: ignore
+MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR")  # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")  # type: ignore
+SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY")  # type: ignore
+COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY")  # type: ignore
+DECLARED_ATTR: int = util.symbol("DECLARED_ATTR")  # type: ignore
+MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY")  # type: ignore
+AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE")  # type: ignore
+AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE")  # type: ignore
+DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN")  # type: ignore
+
+_lookup: Dict[str, Tuple[int, Set[str]]] = {
     "Column": (
         COLUMN,
         {
@@ -145,7 +153,21 @@ _lookup = {
 }
 
 
-def _mro_has_id(mro: List[TypeInfo], type_id: int):
+def _has_base_type_id(info: TypeInfo, type_id: int) -> bool:
+    for mr in info.mro:
+        check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+        if check_type_id == type_id:
+            break
+    else:
+        return False
+
+    if fullnames is None:
+        return False
+
+    return mr.fullname in fullnames
+
+
+def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
     for mr in mro:
         check_type_id, fullnames = _lookup.get(mr.name, (None, None))
         if check_type_id == type_id:
@@ -153,65 +175,75 @@ def _mro_has_id(mro: List[TypeInfo], type_id: int):
     else:
         return False
 
+    if fullnames is None:
+        return False
+
     return mr.fullname in fullnames
 
 
 def _type_id_for_unbound_type(
     type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
-) -> int:
+) -> Optional[int]:
     type_id = None
 
     sym = api.lookup_qualified(type_.name, type_)
     if sym is not None:
         if isinstance(sym.node, TypeAlias):
-            type_id = _type_id_for_named_node(sym.node.target.type)
+            target_type = get_proper_type(sym.node.target)
+            if isinstance(target_type, Instance):
+                type_id = _type_id_for_named_node(target_type.type)
         elif isinstance(sym.node, TypeInfo):
             type_id = _type_id_for_named_node(sym.node)
 
     return type_id
 
 
-def _type_id_for_callee(callee: Expression) -> int:
-    if isinstance(callee.node, FuncDef):
-        return _type_id_for_funcdef(callee.node)
-    elif isinstance(callee.node, TypeAlias):
-        type_id = _type_id_for_fullname(callee.node.target.type.fullname)
-    elif isinstance(callee.node, TypeInfo):
-        type_id = _type_id_for_named_node(callee)
-    else:
-        type_id = None
+def _type_id_for_callee(callee: Expression) -> Optional[int]:
+    if isinstance(callee, (MemberExpr, NameExpr)):
+        if isinstance(callee.node, FuncDef):
+            return _type_id_for_funcdef(callee.node)
+        elif isinstance(callee.node, TypeAlias):
+            target_type = get_proper_type(callee.node.target)
+            if isinstance(target_type, Instance):
+                type_id = _type_id_for_fullname(target_type.type.fullname)
+        elif isinstance(callee.node, TypeInfo):
+            type_id = _type_id_for_named_node(callee)
+        else:
+            type_id = None
     return type_id
 
 
-def _type_id_for_funcdef(node: FuncDef) -> int:
-    if hasattr(node.type.ret_type, "type"):
-        type_id = _type_id_for_fullname(node.type.ret_type.type.fullname)
-    else:
-        type_id = None
-    return type_id
+def _type_id_for_funcdef(node: FuncDef) -> Optional[int]:
+    if node.type and isinstance(node.type, CallableType):
+        ret_type = get_proper_type(node.type.ret_type)
+
+        if isinstance(ret_type, Instance):
+            return _type_id_for_fullname(ret_type.type.fullname)
+
+    return None
 
 
-def _type_id_for_named_node(node: Union[RefExpr, SymbolNode]) -> int:
+def _type_id_for_named_node(
+    node: Union[NameExpr, MemberExpr, SymbolNode]
+) -> Optional[int]:
     type_id, fullnames = _lookup.get(node.name, (None, None))
 
-    if type_id is None:
+    if type_id is None or fullnames is None:
         return None
-
     elif node.fullname in fullnames:
         return type_id
     else:
         return None
 
 
-def _type_id_for_fullname(fullname: str) -> int:
+def _type_id_for_fullname(fullname: str) -> Optional[int]:
     tokens = fullname.split(".")
     immediate = tokens[-1]
 
     type_id, fullnames = _lookup.get(immediate, (None, None))
 
-    if type_id is None:
+    if type_id is None or fullnames is None:
         return None
-
     elif fullname in fullnames:
         return type_id
     else:
index 23585be49b3d48caae2417d06e5f245a15b9a63b..76aac515225bc40da21cb144126923b1d3503a04 100644 (file)
@@ -9,9 +9,12 @@
 Mypy plugin for SQLAlchemy ORM.
 
 """
+from typing import Callable
 from typing import List
+from typing import Optional
 from typing import Tuple
-from typing import Type
+from typing import Type as TypingType
+from typing import Union
 
 from mypy import nodes
 from mypy.mro import calculate_mro
@@ -25,20 +28,20 @@ from mypy.nodes import SymbolTable
 from mypy.nodes import SymbolTableNode
 from mypy.nodes import TypeInfo
 from mypy.plugin import AttributeContext
-from mypy.plugin import Callable
 from mypy.plugin import ClassDefContext
 from mypy.plugin import DynamicClassDefContext
-from mypy.plugin import Optional
 from mypy.plugin import Plugin
 from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import get_proper_type
 from mypy.types import Instance
+from mypy.types import Type
 
 from . import decl_class
 from . import names
 from . import util
 
 
-class CustomPlugin(Plugin):
+class SQLAlchemyPlugin(Plugin):
     def get_dynamic_class_hook(
         self, fullname: str
     ) -> Optional[Callable[[DynamicClassDefContext], None]]:
@@ -72,7 +75,7 @@ class CustomPlugin(Plugin):
 
         sym = self.lookup_fully_qualified(fullname)
 
-        if sym is not None:
+        if sym is not None and sym.node is not None:
             type_id = names._type_id_for_named_node(sym.node)
             if type_id is names.MAPPED_DECORATOR:
                 return _cls_decorator_hook
@@ -109,8 +112,8 @@ class CustomPlugin(Plugin):
         ]
 
 
-def plugin(version: str):
-    return CustomPlugin
+def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
+    return SQLAlchemyPlugin
 
 
 def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
@@ -143,14 +146,14 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None:
         else:
             continue
 
+        assert isinstance(target.expr, NameExpr)
         sym = ctx.api.lookup_qualified(
             target.expr.name, target, suppress_errors=True
         )
-        if sym:
-            if sym.node.type and hasattr(sym.node.type, "type"):
-                target.fullname = (
-                    f"{sym.node.type.type.fullname}.{target.name}"
-                )
+        if sym and sym.node:
+            sym_type = get_proper_type(sym.type)
+            if isinstance(sym_type, Instance):
+                target.fullname = f"{sym_type.type.fullname}.{target.name}"
             else:
                 # if the registry is in the same file as where the
                 # decorator is used, it might not have semantic
@@ -170,7 +173,7 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None:
                 )
 
 
-def _add_globals(ctx: ClassDefContext):
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
     """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
     for all class defs
 
@@ -207,7 +210,15 @@ def _cls_decorator_hook(ctx: ClassDefContext) -> None:
     _add_globals(ctx)
     assert isinstance(ctx.reason, nodes.MemberExpr)
     expr = ctx.reason.expr
-    assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY
+
+    assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
+
+    node_type = get_proper_type(expr.node.type)
+
+    assert (
+        isinstance(node_type, Instance)
+        and names._type_id_for_named_node(node_type.type) is names.REGISTRY
+    )
 
     decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
 
@@ -237,8 +248,8 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
     cls.info = info
     _make_declarative_meta(ctx.api, cls)
 
-    cls_arg = util._get_callexpr_kwarg(ctx.call, "cls")
-    if cls_arg is not None:
+    cls_arg = util._get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+    if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
         decl_class._scan_declarative_assignments_and_apply_types(
             cls_arg.node.defn, ctx.api, is_mixin_scan=True
         )
@@ -263,7 +274,7 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
 
 def _make_declarative_meta(
     api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
-):
+) -> None:
 
     declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta")
     declarative_meta_name.kind = GDEF
@@ -272,6 +283,8 @@ def _make_declarative_meta(
     # installed by _add_globals
     sym = api.lookup_qualified("__sa_DeclarativeMeta", target_cls)
 
+    assert sym is not None and isinstance(sym.node, nodes.TypeInfo)
+
     declarative_meta_typeinfo = sym.node
     declarative_meta_name.node = declarative_meta_typeinfo
 
index 1c1e56d2cba4e404e46c06f21ed53c19ef3c2fc1..26bb0ac67aa6a4987dd10cc32f57de11e6ba8613 100644 (file)
@@ -1,37 +1,52 @@
+from typing import Any
+from typing import cast
+from typing import Iterable
+from typing import Iterator
+from typing import List
 from typing import Optional
-from typing import Sequence
+from typing import overload
 from typing import Tuple
-from typing import Type
+from typing import Type as TypingType
+from typing import TypeVar
+from typing import Union
 
 from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
 from mypy.nodes import CLASSDEF_NO_INFO
 from mypy.nodes import Context
 from mypy.nodes import IfStmt
 from mypy.nodes import JsonDict
 from mypy.nodes import NameExpr
+from mypy.nodes import Statement
 from mypy.nodes import SymbolTableNode
 from mypy.nodes import TypeInfo
 from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
 from mypy.plugin import SemanticAnalyzerPluginInterface
 from mypy.plugins.common import deserialize_and_fixup_type
 from mypy.types import Instance
 from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
 from mypy.types import UnboundType
 from mypy.types import UnionType
 
 
+_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
+
+
 class DeclClassApplied:
     def __init__(
         self,
         is_mapped: bool,
         has_table: bool,
-        mapped_attr_names: Sequence[Tuple[str, Type]],
-        mapped_mro: Sequence[Type],
+        mapped_attr_names: Iterable[Tuple[str, ProperType]],
+        mapped_mro: Iterable[Instance],
     ):
         self.is_mapped = is_mapped
         self.has_table = has_table
-        self.mapped_attr_names = mapped_attr_names
-        self.mapped_mro = mapped_mro
+        self.mapped_attr_names = list(mapped_attr_names)
+        self.mapped_mro = list(mapped_mro)
 
     def serialize(self) -> JsonDict:
         return {
@@ -52,28 +67,34 @@ class DeclClassApplied:
         return DeclClassApplied(
             is_mapped=data["is_mapped"],
             has_table=data["has_table"],
-            mapped_attr_names=[
-                (name, deserialize_and_fixup_type(type_, api))
-                for name, type_ in data["mapped_attr_names"]
-            ],
-            mapped_mro=[
-                deserialize_and_fixup_type(type_, api)
-                for type_ in data["mapped_mro"]
-            ],
+            mapped_attr_names=cast(
+                List[Tuple[str, ProperType]],
+                [
+                    (name, deserialize_and_fixup_type(type_, api))
+                    for name, type_ in data["mapped_attr_names"]
+                ],
+            ),
+            mapped_mro=cast(
+                List[Instance],
+                [
+                    deserialize_and_fixup_type(type_, api)
+                    for type_ in data["mapped_mro"]
+                ],
+            ),
         )
 
 
-def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context):
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
     msg = "[SQLAlchemy Mypy plugin] %s" % msg
     return api.fail(msg, ctx)
 
 
 def add_global(
-    ctx: ClassDefContext,
+    ctx: Union[ClassDefContext, DynamicClassDefContext],
     module: str,
     symbol_name: str,
     asname: str,
-):
+) -> None:
     module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
 
     if asname not in module_globals:
@@ -84,18 +105,50 @@ def add_global(
         module_globals[asname] = lookup_sym
 
 
-def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]:
+@overload
+def _get_callexpr_kwarg(
+    callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+    ...
+
+
+@overload
+def _get_callexpr_kwarg(
+    callexpr: CallExpr,
+    name: str,
+    *,
+    expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+    ...
+
+
+def _get_callexpr_kwarg(
+    callexpr: CallExpr,
+    name: str,
+    *,
+    expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
     try:
         arg_idx = callexpr.arg_names.index(name)
     except ValueError:
         return None
 
-    return callexpr.args[arg_idx]
+    kwarg = callexpr.args[arg_idx]
+    if isinstance(
+        kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+    ):
+        return kwarg
 
+    return None
 
-def _flatten_typechecking(stmts):
+
+def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
     for stmt in stmts:
-        if isinstance(stmt, IfStmt) and stmt.expr[0].name == "TYPE_CHECKING":
+        if (
+            isinstance(stmt, IfStmt)
+            and isinstance(stmt.expr[0], NameExpr)
+            and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+        ):
             for substmt in stmt.body[0].body:
                 yield substmt
         else:
@@ -103,7 +156,7 @@ def _flatten_typechecking(stmts):
 
 
 def _unbound_to_instance(
-    api: SemanticAnalyzerPluginInterface, typ: UnboundType
+    api: SemanticAnalyzerPluginInterface, typ: Type
 ) -> Type:
     """Take the UnboundType that we seem to get as the ret_type from a FuncDef
     and convert it into an Instance/TypeInfo kind of structure that seems
@@ -130,7 +183,11 @@ def _unbound_to_instance(
 
     node = api.lookup_qualified(typ.name, typ)
 
-    if node is not None and isinstance(node, SymbolTableNode):
+    if (
+        node is not None
+        and isinstance(node, SymbolTableNode)
+        and isinstance(node.node, TypeInfo)
+    ):
         bound_type = node.node
 
         return Instance(
@@ -146,12 +203,12 @@ def _unbound_to_instance(
         return typ
 
 
-def _info_for_cls(cls, api):
+def _info_for_cls(
+    cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> TypeInfo:
     if cls.info is CLASSDEF_NO_INFO:
         sym = api.lookup_qualified(cls.name, cls)
-        if sym.node and isinstance(sym.node, TypeInfo):
-            info = sym.node
-    else:
-        info = cls.info
+        assert sym and isinstance(sym.node, TypeInfo)
+        return sym.node
 
-    return info
+    return cls.info
index 7e328d5b5f1a8f806dba50b753b226d9f1d3a82a..b0eebb2855cc157f047ade93c1059a62d40ad996 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -118,7 +118,14 @@ per-file-ignores =
 
 [mypy]
 # min mypy version 0.800
-plugins = sqlalchemy.ext.mypy.plugin
+strict = True
+incremental = True
+
+[mypy-sqlalchemy.*]
+ignore_errors = True
+
+[mypy-sqlalchemy.ext.mypy.*]
+ignore_errors = False
 
 [sqla_testing]
 requirement_cls = test.requirements:DefaultRequirements