git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Refactor mypy plugin 6764/head
author    Bryan Forbes <bryan@reigndropsfall.net>
          Wed, 14 Jul 2021 18:48:30 +0000 (13:48 -0500)
committer Bryan Forbes <bryan@reigndropsfall.net>
          Wed, 14 Jul 2021 18:48:30 +0000 (13:48 -0500)
Change-Id: I067d56dcfbc998ddd1b22a448f756859428b9e31

lib/sqlalchemy/ext/mypy/apply.py
lib/sqlalchemy/ext/mypy/decl_class.py
lib/sqlalchemy/ext/mypy/infer.py
lib/sqlalchemy/ext/mypy/names.py
lib/sqlalchemy/ext/mypy/plugin.py
lib/sqlalchemy/ext/mypy/util.py
setup.cfg

index 293ef2f9a5f8cb4e0556b76cc53000510b975020..cf5b4fda257ec42eb9fb643f6e2f12f97bbec17b 100644 (file)
--- a/lib/sqlalchemy/ext/mypy/apply.py
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -5,10 +5,10 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+from typing import List
 from typing import Optional
 from typing import Union
 
-from mypy import nodes
 from mypy.nodes import ARG_NAMED_OPT
 from mypy.nodes import Argument
 from mypy.nodes import AssignmentStmt
@@ -17,6 +17,7 @@ from mypy.nodes import ClassDef
 from mypy.nodes import MDEF
 from mypy.nodes import MemberExpr
 from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
 from mypy.nodes import StrExpr
 from mypy.nodes import SymbolTableNode
 from mypy.nodes import TempNode
@@ -37,18 +38,18 @@ from . import infer
 from . import util
 
 
-def _apply_mypy_mapped_attr(
+def apply_mypy_mapped_attr(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
     item: Union[NameExpr, StrExpr],
-    cls_metadata: util.DeclClassApplied,
+    attributes: List[util.SQLAlchemyAttribute],
 ) -> None:
     if isinstance(item, NameExpr):
         name = item.name
     elif isinstance(item, StrExpr):
         name = item.value
     else:
-        return
+        return None
 
     for stmt in cls.defs.body:
         if (
@@ -59,7 +60,7 @@ def _apply_mypy_mapped_attr(
             break
     else:
         util.fail(api, "Can't find mapped attribute {}".format(name), cls)
-        return
+        return None
 
     if stmt.type is None:
         util.fail(
@@ -68,32 +69,38 @@ def _apply_mypy_mapped_attr(
             "typing information",
             stmt,
         )
-        return
+        return None
 
     left_hand_explicit_type = get_proper_type(stmt.type)
     assert isinstance(
         left_hand_explicit_type, (Instance, UnionType, UnboundType)
     )
 
-    cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+    attributes.append(
+        util.SQLAlchemyAttribute(
+            name=name,
+            line=item.line,
+            column=item.column,
+            typ=left_hand_explicit_type,
+            info=cls.info,
+        )
+    )
 
-    _apply_type_to_mapped_statement(
+    apply_type_to_mapped_statement(
         api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
     )
 
 
-def _re_apply_declarative_assignments(
+def re_apply_declarative_assignments(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
-    cls_metadata: util.DeclClassApplied,
+    attributes: List[util.SQLAlchemyAttribute],
 ) -> None:
     """For multiple class passes, re-apply our left-hand side types as mypy
     seems to reset them in place.
 
     """
-    mapped_attr_lookup = {
-        name: typ for name, typ in cls_metadata.mapped_attr_names
-    }
+    mapped_attr_lookup = {attr.name: attr for attr in attributes}
     update_cls_metadata = False
 
     for stmt in cls.defs.body:
@@ -109,28 +116,37 @@ def _re_apply_declarative_assignments(
         ):
 
             left_node = stmt.lvalues[0].node
-            python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name]
+            python_type_for_type = mapped_attr_lookup[
+                stmt.lvalues[0].name
+            ].type
+
+            left_node_proper_type = get_proper_type(left_node.type)
+
             # if we have scanned an UnboundType and now there's a more
             # specific type than UnboundType, call the re-scan so we
             # can get that set up correctly
             if (
                 isinstance(python_type_for_type, UnboundType)
-                and not isinstance(left_node.type, UnboundType)
+                and not isinstance(left_node_proper_type, UnboundType)
                 and (
-                    isinstance(stmt.rvalue.callee, MemberExpr)
+                    isinstance(stmt.rvalue, CallExpr)
+                    and isinstance(stmt.rvalue.callee, MemberExpr)
+                    and isinstance(stmt.rvalue.callee.expr, NameExpr)
+                    and stmt.rvalue.callee.expr.node is not None
                     and stmt.rvalue.callee.expr.node.fullname
                     == "sqlalchemy.orm.attributes.Mapped"
                     and stmt.rvalue.callee.name == "_empty_constructor"
                     and isinstance(stmt.rvalue.args[0], CallExpr)
+                    and isinstance(stmt.rvalue.args[0].callee, RefExpr)
                 )
             ):
 
                 python_type_for_type = (
-                    infer._infer_type_from_right_hand_nameexpr(
+                    infer.infer_type_from_right_hand_nameexpr(
                         api,
                         stmt,
                         left_node,
-                        left_node.type,
+                        left_node_proper_type,
                         stmt.rvalue.args[0].callee,
                     )
                 )
@@ -140,21 +156,23 @@ def _re_apply_declarative_assignments(
                 ):
                     continue
 
-                # update the DeclClassApplied with the better information
-                mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type
+                # update the SQLAlchemyAttribute with the better information
+                mapped_attr_lookup[
+                    stmt.lvalues[0].name
+                ].type = python_type_for_type
+
                 update_cls_metadata = True
 
-            left_node.type = api.named_type(
-                "__sa_Mapped", [python_type_for_type]
-            )
+            if python_type_for_type is not None:
+                left_node.type = api.named_type(
+                    "__sa_Mapped", [python_type_for_type]
+                )
 
     if update_cls_metadata:
-        cls_metadata.mapped_attr_names[:] = [
-            (k, v) for k, v in mapped_attr_lookup.items()
-        ]
+        util.set_mapped_attributes(cls.info, attributes)
 
 
-def _apply_type_to_mapped_statement(
+def apply_type_to_mapped_statement(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     lvalue: NameExpr,
@@ -205,30 +223,36 @@ def _apply_type_to_mapped_statement(
     # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
     # the original right-hand side is maintained so it gets type checked
     # internally
-    column_descriptor = nodes.NameExpr("__sa_Mapped")
-    column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
-    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
-    orig_call_expr = stmt.rvalue
-    stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
+    stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
 
 
-def _add_additional_orm_attributes(
+def add_additional_orm_attributes(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
-    cls_metadata: util.DeclClassApplied,
+    attributes: List[util.SQLAlchemyAttribute],
 ) -> None:
     """Apply __init__, __table__ and other attributes to the mapped class."""
 
-    info = util._info_for_cls(cls, api)
-    if "__init__" not in info.names and cls_metadata.is_mapped:
-        mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
+    info = util.info_for_cls(cls, api)
 
-        for mapped_base in cls_metadata.mapped_mro:
-            base_cls_metadata = util.DeclClassApplied.deserialize(
-                mapped_base.type.metadata["_sa_decl_class_applied"], api
-            )
-            for n, t in base_cls_metadata.mapped_attr_names:
-                mapped_attr_names.setdefault(n, t)
+    if info is None:
+        return
+
+    is_base = util.get_is_base(info)
+
+    if "__init__" not in info.names and not is_base:
+        mapped_attr_names = {attr.name: attr.type for attr in attributes}
+
+        for base in info.mro[1:-1]:
+            if "sqlalchemy" not in info.metadata:
+                continue
+
+            base_cls_attributes = util.get_mapped_attributes(base, api)
+            if base_cls_attributes is None:
+                continue
+
+            for attr in base_cls_attributes:
+                mapped_attr_names.setdefault(attr.name, attr.type)
 
         arguments = []
         for name, typ in mapped_attr_names.items():
@@ -242,13 +266,14 @@ def _add_additional_orm_attributes(
                     kind=ARG_NAMED_OPT,
                 )
             )
+
         add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
 
-    if "__table__" not in info.names and cls_metadata.has_table:
+    if "__table__" not in info.names and util.get_has_table(info):
         _apply_placeholder_attr_to_class(
             api, cls, "sqlalchemy.sql.schema.Table", "__table__"
         )
-    if cls_metadata.is_mapped:
+    if not is_base:
         _apply_placeholder_attr_to_class(
             api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
         )
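
A minimal user-facing sketch of the pattern that apply_mypy_mapped_attr() above
services; the class and column names are illustrative assumptions, not part of the
commit. Attributes listed in _mypy_mapped_attrs need an explicit annotation, which
the plugin then re-applies as Mapped[...] on the left-hand side:

    from typing import Any, List

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.orm import registry

    mapper_registry = registry()


    @mapper_registry.mapped
    class User:
        __tablename__ = "user"

        id: int = Column(Integer, primary_key=True)
        name: str = Column(String(50))

        # each entry is resolved by apply_mypy_mapped_attr(); an entry whose
        # assignment lacks typing information triggers util.fail() as above
        _mypy_mapped_attrs: List[Any] = [id, name]
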
index 45d025fc99341e0ce7f82b9e1fdba627c2aefd99..23c78aa51fc47b12c063d25a0eeda875fe4cf5ee 100644 (file)
--- a/lib/sqlalchemy/ext/mypy/decl_class.py
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -5,14 +5,15 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+from typing import List
 from typing import Optional
 from typing import Union
 
-from mypy import nodes
 from mypy.nodes import AssignmentStmt
 from mypy.nodes import CallExpr
 from mypy.nodes import ClassDef
 from mypy.nodes import Decorator
+from mypy.nodes import LambdaExpr
 from mypy.nodes import ListExpr
 from mypy.nodes import MemberExpr
 from mypy.nodes import NameExpr
@@ -42,62 +43,68 @@ from . import names
 from . import util
 
 
-def _scan_declarative_assignments_and_apply_types(
+def scan_declarative_assignments_and_apply_types(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
     is_mixin_scan: bool = False,
-) -> Optional[util.DeclClassApplied]:
+) -> Optional[List[util.SQLAlchemyAttribute]]:
 
-    info = util._info_for_cls(cls, api)
+    info = util.info_for_cls(cls, api)
 
     if info is None:
         # this can occur during cached passes
         return None
     elif cls.fullname.startswith("builtins"):
         return None
-    elif "_sa_decl_class_applied" in info.metadata:
-        cls_metadata = util.DeclClassApplied.deserialize(
-            info.metadata["_sa_decl_class_applied"], api
-        )
 
+    mapped_attributes: Optional[
+        List[util.SQLAlchemyAttribute]
+    ] = util.get_mapped_attributes(info, api)
+
+    if mapped_attributes is not None:
         # ensure that a class that's mapped is always picked up by
         # its mapped() decorator or declarative metaclass before
         # it would be detected as an unmapped mixin class
-        if not is_mixin_scan:
-            assert cls_metadata.is_mapped
 
+        if not is_mixin_scan:
             # mypy can call us more than once.  it then *may* have reset the
             # left hand side of everything, but not the right that we removed,
             # removing our ability to re-scan.   but we have the types
             # here, so lets re-apply them, or if we have an UnboundType,
             # we can re-scan
 
-            apply._re_apply_declarative_assignments(cls, api, cls_metadata)
+            apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
 
-        return cls_metadata
+        return mapped_attributes
 
-    cls_metadata = util.DeclClassApplied(not is_mixin_scan, False, [], [])
+    mapped_attributes = []
 
     if not cls.defs.body:
         # when we get a mixin class from another file, the body is
         # empty (!) but the names are in the symbol table.  so use that.
 
         for sym_name, sym in info.names.items():
-            _scan_symbol_table_entry(cls, api, sym_name, sym, cls_metadata)
+            _scan_symbol_table_entry(
+                cls, api, sym_name, sym, mapped_attributes
+            )
     else:
-        for stmt in util._flatten_typechecking(cls.defs.body):
+        for stmt in util.flatten_typechecking(cls.defs.body):
             if isinstance(stmt, AssignmentStmt):
-                _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
+                _scan_declarative_assignment_stmt(
+                    cls, api, stmt, mapped_attributes
+                )
             elif isinstance(stmt, Decorator):
-                _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
-    _scan_for_mapped_bases(cls, api, cls_metadata)
+                _scan_declarative_decorator_stmt(
+                    cls, api, stmt, mapped_attributes
+                )
+    _scan_for_mapped_bases(cls, api)
 
     if not is_mixin_scan:
-        apply._add_additional_orm_attributes(cls, api, cls_metadata)
+        apply.add_additional_orm_attributes(cls, api, mapped_attributes)
 
-    info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()
+    util.set_mapped_attributes(info, mapped_attributes)
 
-    return cls_metadata
+    return mapped_attributes
 
 
 def _scan_symbol_table_entry(
@@ -105,7 +112,7 @@ def _scan_symbol_table_entry(
     api: SemanticAnalyzerPluginInterface,
     name: str,
     value: SymbolTableNode,
-    cls_metadata: util.DeclClassApplied,
+    attributes: List[util.SQLAlchemyAttribute],
 ) -> None:
     """Extract mapping information from a SymbolTableNode that's in the
     type.names dictionary.
@@ -116,7 +123,7 @@ def _scan_symbol_table_entry(
         return
 
     left_hand_explicit_type = None
-    type_id = names._type_id_for_named_node(value_type.type)
+    type_id = names.type_id_for_named_node(value_type.type)
     # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
 
     err = False
@@ -148,11 +155,11 @@ def _scan_symbol_table_entry(
             if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
                 sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
                 if sym is not None and isinstance(sym.node, TypeInfo):
-                    if names._has_base_type_id(sym.node, names.TYPEENGINE):
+                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
 
                         left_hand_explicit_type = UnionType(
                             [
-                                infer._extract_python_type_from_typeengine(
+                                infer.extract_python_type_from_typeengine(
                                     api, sym.node, []
                                 ),
                                 NoneType(),
@@ -178,14 +185,23 @@ def _scan_symbol_table_entry(
         left_hand_explicit_type = AnyType(TypeOfAny.special_form)
 
     if left_hand_explicit_type is not None:
-        cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+        assert value.node is not None
+        attributes.append(
+            util.SQLAlchemyAttribute(
+                name=name,
+                line=value.node.line,
+                column=value.node.column,
+                typ=left_hand_explicit_type,
+                info=cls.info,
+            )
+        )
 
 
 def _scan_declarative_decorator_stmt(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
     stmt: Decorator,
-    cls_metadata: util.DeclClassApplied,
+    attributes: List[util.SQLAlchemyAttribute],
 ) -> None:
     """Extract mapping information from a @declared_attr in a declarative
     class.
@@ -212,7 +228,7 @@ def _scan_declarative_decorator_stmt(
     for dec in stmt.decorators:
         if (
             isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
-            and names._type_id_for_named_node(dec) is names.DECLARED_ATTR
+            and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
         ):
             break
     else:
@@ -225,7 +241,7 @@ def _scan_declarative_decorator_stmt(
     if isinstance(stmt.func.type, CallableType):
         func_type = stmt.func.type.ret_type
         if isinstance(func_type, UnboundType):
-            type_id = names._type_id_for_unbound_type(func_type, cls, api)
+            type_id = names.type_id_for_unbound_type(func_type, cls, api)
         else:
             # this does not seem to occur unless the type argument is
             # incorrect
@@ -249,10 +265,10 @@ def _scan_declarative_decorator_stmt(
             if isinstance(typeengine_arg, UnboundType):
                 sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
                 if sym is not None and isinstance(sym.node, TypeInfo):
-                    if names._has_base_type_id(sym.node, names.TYPEENGINE):
+                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
                         left_hand_explicit_type = UnionType(
                             [
-                                infer._extract_python_type_from_typeengine(
+                                infer.extract_python_type_from_typeengine(
                                     api, sym.node, []
                                 ),
                                 NoneType(),
@@ -291,7 +307,7 @@ def _scan_declarative_decorator_stmt(
     # we see everywhere else.
     if isinstance(left_hand_explicit_type, UnboundType):
         left_hand_explicit_type = get_proper_type(
-            util._unbound_to_instance(api, left_hand_explicit_type)
+            util.unbound_to_instance(api, left_hand_explicit_type)
         )
 
     left_node.node.type = api.named_type(
@@ -305,23 +321,21 @@ def _scan_declarative_decorator_stmt(
     # <attr> : Mapped[<typ>] =
     # _sa_Mapped._empty_constructor(lambda: <function body>)
     # the function body is maintained so it gets type checked internally
-    column_descriptor = nodes.NameExpr("__sa_Mapped")
-    column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
-    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
-
-    arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body)
-    rvalue = CallExpr(
-        mm,
-        [arg],
-        [nodes.ARG_POS],
-        ["arg1"],
+    rvalue = util.expr_to_mapped_constructor(
+        LambdaExpr(stmt.func.arguments, stmt.func.body)
     )
 
     new_stmt = AssignmentStmt([left_node], rvalue)
     new_stmt.type = left_node.node.type
 
-    cls_metadata.mapped_attr_names.append(
-        (left_node.name, left_hand_explicit_type)
+    attributes.append(
+        util.SQLAlchemyAttribute(
+            name=left_node.name,
+            line=stmt.line,
+            column=stmt.column,
+            typ=left_hand_explicit_type,
+            info=cls.info,
+        )
     )
     cls.defs.body[dec_index] = new_stmt
 
@@ -330,7 +344,7 @@ def _scan_declarative_assignment_stmt(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
-    cls_metadata: util.DeclClassApplied,
+    attributes: List[util.SQLAlchemyAttribute],
 ) -> None:
     """Extract mapping information from an assignment statement in a
     declarative class.
@@ -356,10 +370,10 @@ def _scan_declarative_assignment_stmt(
 
     if node.name == "__abstract__":
         if api.parse_bool(stmt.rvalue) is True:
-            cls_metadata.is_mapped = False
+            util.set_is_base(cls.info)
         return
     elif node.name == "__tablename__":
-        cls_metadata.has_table = True
+        util.set_has_table(cls.info)
     elif node.name.startswith("__"):
         return
     elif node.name == "_mypy_mapped_attrs":
@@ -368,7 +382,7 @@ def _scan_declarative_assignment_stmt(
         else:
             for item in stmt.rvalue.items:
                 if isinstance(item, (NameExpr, StrExpr)):
-                    apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
+                    apply.apply_mypy_mapped_attr(cls, api, item, attributes)
 
     left_hand_mapped_type: Optional[Type] = None
     left_hand_explicit_type: Optional[ProperType] = None
@@ -388,7 +402,7 @@ def _scan_declarative_assignment_stmt(
                 if (
                     mapped_sym is not None
                     and mapped_sym.node is not None
-                    and names._type_id_for_named_node(mapped_sym.node)
+                    and names.type_id_for_named_node(mapped_sym.node)
                     is names.MAPPED
                 ):
                     left_hand_explicit_type = get_proper_type(
@@ -404,7 +418,7 @@ def _scan_declarative_assignment_stmt(
         node_type = get_proper_type(node.type)
         if (
             isinstance(node_type, Instance)
-            and names._type_id_for_named_node(node_type.type) is names.MAPPED
+            and names.type_id_for_named_node(node_type.type) is names.MAPPED
         ):
             # print(node.type)
             # sqlalchemy.orm.attributes.Mapped[<python type>]
@@ -426,7 +440,7 @@ def _scan_declarative_assignment_stmt(
         stmt.rvalue.callee, RefExpr
     ):
 
-        python_type_for_type = infer._infer_type_from_right_hand_nameexpr(
+        python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
             api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
         )
 
@@ -438,9 +452,17 @@ def _scan_declarative_assignment_stmt(
 
     assert python_type_for_type is not None
 
-    cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+    attributes.append(
+        util.SQLAlchemyAttribute(
+            name=node.name,
+            line=stmt.line,
+            column=stmt.column,
+            typ=python_type_for_type,
+            info=cls.info,
+        )
+    )
 
-    apply._apply_type_to_mapped_statement(
+    apply.apply_type_to_mapped_statement(
         api,
         stmt,
         lvalue,
@@ -452,7 +474,6 @@ def _scan_declarative_assignment_stmt(
 def _scan_for_mapped_bases(
     cls: ClassDef,
     api: SemanticAnalyzerPluginInterface,
-    cls_metadata: util.DeclClassApplied,
 ) -> None:
     """Given a class, iterate through its superclass hierarchy to find
     all other classes that are considered as ORM-significant.
@@ -462,25 +483,18 @@ def _scan_for_mapped_bases(
 
     """
 
-    info = util._info_for_cls(cls, api)
+    info = util.info_for_cls(cls, api)
 
-    baseclasses = list(info.bases)
-
-    while baseclasses:
-        base: Instance = baseclasses.pop(0)
+    if info is None:
+        return
 
-        if base.type.fullname.startswith("builtins"):
+    for base_info in info.mro[1:-1]:
+        if base_info.fullname.startswith("builtins"):
             continue
 
         # scan each base for mapped attributes.  if they are not already
         # scanned (but have all their type info), that means they are unmapped
         # mixins
-        base_decl_class_applied = (
-            _scan_declarative_assignments_and_apply_types(
-                base.type.defn, api, is_mixin_scan=True
-            )
+        scan_declarative_assignments_and_apply_types(
+            base_info.defn, api, is_mixin_scan=True
         )
-
-        if base_decl_class_applied is not None:
-            cls_metadata.mapped_mro.append(base)
-        baseclasses.extend(base.type.bases)
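
A rough illustration of what _scan_declarative_decorator_stmt() handles; the mixin
and attribute names are assumptions, not taken from the commit. A @declared_attr on
a mixin supplies the left-hand type through its declared return annotation:

    from sqlalchemy import Column, DateTime, Integer
    from sqlalchemy.orm import declarative_base, declared_attr

    Base = declarative_base()


    class HasUpdatedAt:
        # the scan reads the declared return type (Column[DateTime]) and,
        # roughly, applies Mapped[Optional[datetime]] to the attribute
        @declared_attr
        def updated_at(cls) -> "Column[DateTime]":
            return Column(DateTime)


    class User(HasUpdatedAt, Base):
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
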
index ca2b62966e2a379876a3ba5fac1f636c390d7809..85a94bba61cf49e08bdec89ede137d91e796fb13 100644 (file)
--- a/lib/sqlalchemy/ext/mypy/infer.py
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -35,15 +35,15 @@ from . import names
 from . import util
 
 
-def _infer_type_from_right_hand_nameexpr(
+def infer_type_from_right_hand_nameexpr(
     api: SemanticAnalyzerPluginInterface,
     stmt: AssignmentStmt,
     node: Var,
     left_hand_explicit_type: Optional[ProperType],
-    infer_from_right_side: NameExpr,
+    infer_from_right_side: RefExpr,
 ) -> Optional[ProperType]:
 
-    type_id = names._type_id_for_callee(infer_from_right_side)
+    type_id = names.type_id_for_callee(infer_from_right_side)
 
     if type_id is None:
         return None
@@ -60,7 +60,7 @@ def _infer_type_from_right_hand_nameexpr(
             api, stmt, node, left_hand_explicit_type
         )
     elif type_id is names.SYNONYM_PROPERTY:
-        python_type_for_type = _infer_type_from_left_hand_type_only(
+        python_type_for_type = infer_type_from_left_hand_type_only(
             api, node, left_hand_explicit_type
         )
     elif type_id is names.COMPOSITE_PROPERTY:
@@ -128,8 +128,8 @@ def _infer_type_from_relationship(
     # string expression
     # isinstance(target_cls_arg, StrExpr)
 
-    uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
-    collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg(
+    uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
+    collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
         stmt.rvalue, "collection_class"
     )
     type_is_a_collection = False
@@ -137,7 +137,7 @@ def _infer_type_from_relationship(
     # this can be used to determine Optional for a many-to-one
     # in the same way nullable=False could be used, if we start supporting
     # that.
-    # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+    # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
 
     if (
         uselist_arg is not None
@@ -218,7 +218,7 @@ def _infer_type_from_relationship(
             util.fail(api, msg.format(node.name), node)
 
     if python_type_for_type is None:
-        return _infer_type_from_left_hand_type_only(
+        return infer_type_from_left_hand_type_only(
             api, node, left_hand_explicit_type
         )
     elif left_hand_explicit_type is not None:
@@ -260,7 +260,7 @@ def _infer_type_from_decl_composite_property(
         python_type_for_type = None
 
     if python_type_for_type is None:
-        return _infer_type_from_left_hand_type_only(
+        return infer_type_from_left_hand_type_only(
             api, node, left_hand_explicit_type
         )
     elif left_hand_explicit_type is not None:
@@ -287,7 +287,7 @@ def _infer_type_from_decl_column_property(
     first_prop_arg = stmt.rvalue.args[0]
 
     if isinstance(first_prop_arg, CallExpr):
-        type_id = names._type_id_for_callee(first_prop_arg.callee)
+        type_id = names.type_id_for_callee(first_prop_arg.callee)
 
         # look for column_property() / deferred() etc with Column as first
         # argument
@@ -300,7 +300,7 @@ def _infer_type_from_decl_column_property(
                 right_hand_expression=first_prop_arg,
             )
 
-    return _infer_type_from_left_hand_type_only(
+    return infer_type_from_left_hand_type_only(
         api, node, left_hand_explicit_type
     )
 
@@ -378,10 +378,10 @@ def _infer_type_from_decl_column(
     if callee is None:
         return None
 
-    if isinstance(callee.node, TypeInfo) and names._mro_has_id(
+    if isinstance(callee.node, TypeInfo) and names.mro_has_id(
         callee.node.mro, names.TYPEENGINE
     ):
-        python_type_for_type = _extract_python_type_from_typeengine(
+        python_type_for_type = extract_python_type_from_typeengine(
             api, callee.node, type_args
         )
 
@@ -396,7 +396,7 @@ def _infer_type_from_decl_column(
     else:
         # it's not TypeEngine, it's typically implicitly typed
         # like ForeignKey.  we can't infer from the right side.
-        return _infer_type_from_left_hand_type_only(
+        return infer_type_from_left_hand_type_only(
             api, node, left_hand_explicit_type
         )
 
@@ -472,7 +472,7 @@ def _infer_collection_type_from_left_and_inferred_right(
     )
 
 
-def _infer_type_from_left_hand_type_only(
+def infer_type_from_left_hand_type_only(
     api: SemanticAnalyzerPluginInterface,
     node: Var,
     left_hand_explicit_type: Optional[ProperType],
@@ -499,7 +499,7 @@ def _infer_type_from_left_hand_type_only(
         return left_hand_explicit_type
 
 
-def _extract_python_type_from_typeengine(
+def extract_python_type_from_typeengine(
     api: SemanticAnalyzerPluginInterface,
     node: TypeInfo,
     type_args: Sequence[Expression],
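
A small sketch of the inference these helpers perform when no left-hand annotation
is present (an assumed example, not from the commit): the Column's TypeEngine
argument drives the Python type, and plain columns are treated as nullable:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        # extract_python_type_from_typeengine() maps String -> str, and the
        # column is assumed nullable, so the attribute ends up typed as
        # Mapped[Optional[str]]; reveal_type(User().name) under mypy is
        # expected to show Union[builtins.str, None]
        name = Column(String(50))
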
index 653ce4985a0dc883f07e88d7a4160811607ac432..22a79e29b96ba88cc62cc0c97dbdd42c4e416701 100644 (file)
--- a/lib/sqlalchemy/ext/mypy/names.py
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -153,7 +153,7 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
 }
 
 
-def _has_base_type_id(info: TypeInfo, type_id: int) -> bool:
+def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
     for mr in info.mro:
         check_type_id, fullnames = _lookup.get(mr.name, (None, None))
         if check_type_id == type_id:
@@ -167,7 +167,7 @@ def _has_base_type_id(info: TypeInfo, type_id: int) -> bool:
     return mr.fullname in fullnames
 
 
-def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
+def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
     for mr in mro:
         check_type_id, fullnames = _lookup.get(mr.name, (None, None))
         if check_type_id == type_id:
@@ -181,49 +181,41 @@ def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
     return mr.fullname in fullnames
 
 
-def _type_id_for_unbound_type(
+def type_id_for_unbound_type(
     type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
 ) -> Optional[int]:
-    type_id = None
-
     sym = api.lookup_qualified(type_.name, type_)
     if sym is not None:
         if isinstance(sym.node, TypeAlias):
             target_type = get_proper_type(sym.node.target)
             if isinstance(target_type, Instance):
-                type_id = _type_id_for_named_node(target_type.type)
+                return type_id_for_named_node(target_type.type)
         elif isinstance(sym.node, TypeInfo):
-            type_id = _type_id_for_named_node(sym.node)
+            return type_id_for_named_node(sym.node)
 
-    return type_id
+    return None
 
 
-def _type_id_for_callee(callee: Expression) -> Optional[int]:
+def type_id_for_callee(callee: Expression) -> Optional[int]:
     if isinstance(callee, (MemberExpr, NameExpr)):
         if isinstance(callee.node, FuncDef):
-            return _type_id_for_funcdef(callee.node)
+            if callee.node.type and isinstance(callee.node.type, CallableType):
+                ret_type = get_proper_type(callee.node.type.ret_type)
+
+                if isinstance(ret_type, Instance):
+                    return type_id_for_fullname(ret_type.type.fullname)
+
+            return None
         elif isinstance(callee.node, TypeAlias):
             target_type = get_proper_type(callee.node.target)
             if isinstance(target_type, Instance):
-                type_id = _type_id_for_fullname(target_type.type.fullname)
+                return type_id_for_fullname(target_type.type.fullname)
         elif isinstance(callee.node, TypeInfo):
-            type_id = _type_id_for_named_node(callee)
-        else:
-            type_id = None
-    return type_id
-
-
-def _type_id_for_funcdef(node: FuncDef) -> Optional[int]:
-    if node.type and isinstance(node.type, CallableType):
-        ret_type = get_proper_type(node.type.ret_type)
-
-        if isinstance(ret_type, Instance):
-            return _type_id_for_fullname(ret_type.type.fullname)
-
+            return type_id_for_named_node(callee)
     return None
 
 
-def _type_id_for_named_node(
+def type_id_for_named_node(
     node: Union[NameExpr, MemberExpr, SymbolNode]
 ) -> Optional[int]:
     type_id, fullnames = _lookup.get(node.name, (None, None))
@@ -236,7 +228,7 @@ def _type_id_for_named_node(
         return None
 
 
-def _type_id_for_fullname(fullname: str) -> Optional[int]:
+def type_id_for_fullname(fullname: str) -> Optional[int]:
     tokens = fullname.split(".")
     immediate = tokens[-1]
 
index 687aeb8513118c4b1415b811e0300c0df9b11f4a..356b0d9489eb2066e17b98988378389925903702 100644 (file)
--- a/lib/sqlalchemy/ext/mypy/plugin.py
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -45,29 +45,14 @@ class SQLAlchemyPlugin(Plugin):
     def get_dynamic_class_hook(
         self, fullname: str
     ) -> Optional[Callable[[DynamicClassDefContext], None]]:
-        if names._type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+        if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
             return _dynamic_class_hook
         return None
 
-    def get_base_class_hook(
+    def get_customize_class_mro_hook(
         self, fullname: str
     ) -> Optional[Callable[[ClassDefContext], None]]:
-
-        # kind of a strange relationship between get_metaclass_hook()
-        # and get_base_class_hook().  the former doesn't fire off for
-        # subclasses.   but then you can just check it here from the "base"
-        # and get the same effect.
-        sym = self.lookup_fully_qualified(fullname)
-
-        if (
-            sym
-            and isinstance(sym.node, TypeInfo)
-            and sym.node.metaclass_type
-            and names._type_id_for_named_node(sym.node.metaclass_type.type)
-            is names.DECLARATIVE_META
-        ):
-            return _base_cls_hook
-        return None
+        return _fill_in_decorators
 
     def get_class_decorator_hook(
         self, fullname: str
@@ -76,7 +61,7 @@ class SQLAlchemyPlugin(Plugin):
         sym = self.lookup_fully_qualified(fullname)
 
         if sym is not None and sym.node is not None:
-            type_id = names._type_id_for_named_node(sym.node)
+            type_id = names.type_id_for_named_node(sym.node)
             if type_id is names.MAPPED_DECORATOR:
                 return _cls_decorator_hook
             elif type_id in (
@@ -89,10 +74,29 @@ class SQLAlchemyPlugin(Plugin):
 
         return None
 
-    def get_customize_class_mro_hook(
+    def get_metaclass_hook(
         self, fullname: str
     ) -> Optional[Callable[[ClassDefContext], None]]:
-        return _fill_in_decorators
+        if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
+            # Set any classes that explicitly have metaclass=DeclarativeMeta
+            # as declarative so the check in `get_base_class_hook()` works
+            return _metaclass_cls_hook
+
+        return None
+
+    def get_base_class_hook(
+        self, fullname: str
+    ) -> Optional[Callable[[ClassDefContext], None]]:
+        sym = self.lookup_fully_qualified(fullname)
+
+        if (
+            sym
+            and isinstance(sym.node, TypeInfo)
+            and util.has_declarative_base(sym.node)
+        ):
+            return _base_cls_hook
+
+        return None
 
     def get_attribute_hook(
         self, fullname: str
@@ -101,6 +105,7 @@ class SQLAlchemyPlugin(Plugin):
             "sqlalchemy.orm.attributes.QueryableAttribute."
         ):
             return _queryable_getattr_hook
+
         return None
 
     def get_additional_deps(
@@ -116,10 +121,43 @@ def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
     return SQLAlchemyPlugin
 
 
-def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
-    # how do I....tell it it has no attribute of a certain name?
-    # can't find any Type that seems to match that
-    return ctx.default_attr_type
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+    """Generate a declarative Base class when the declarative_base() function
+    is encountered."""
+
+    _add_globals(ctx)
+
+    cls = ClassDef(ctx.name, Block([]))
+    cls.fullname = ctx.api.qualified_name(ctx.name)
+
+    info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+    cls.info = info
+    _set_declarative_metaclass(ctx.api, cls)
+
+    cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+    if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
+        util.set_is_base(cls_arg.node)
+        decl_class.scan_declarative_assignments_and_apply_types(
+            cls_arg.node.defn, ctx.api, is_mixin_scan=True
+        )
+        info.bases = [Instance(cls_arg.node, [])]
+    else:
+        obj = ctx.api.named_type("__builtins__.object")
+
+        info.bases = [obj]
+
+    try:
+        calculate_mro(info)
+    except MroError:
+        util.fail(
+            ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+        )
+        obj = ctx.api.named_type("__builtins__.object")
+        info.bases = [obj]
+        info.fallback_to_any = True
+
+    ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+    util.set_is_base(info)
 
 
 def _fill_in_decorators(ctx: ClassDefContext) -> None:
@@ -173,39 +211,6 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None:
                 )
 
 
-def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
-    """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
-    for all class defs
-
-    """
-
-    util.add_global(
-        ctx,
-        "sqlalchemy.orm.decl_api",
-        "DeclarativeMeta",
-        "__sa_DeclarativeMeta",
-    )
-
-    util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
-
-
-def _cls_metadata_hook(ctx: ClassDefContext) -> None:
-    _add_globals(ctx)
-    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
-
-
-def _base_cls_hook(ctx: ClassDefContext) -> None:
-    _add_globals(ctx)
-    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
-
-
-def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
-    _add_globals(ctx)
-    decl_class._scan_declarative_assignments_and_apply_types(
-        ctx.cls, ctx.api, is_mixin_scan=True
-    )
-
-
 def _cls_decorator_hook(ctx: ClassDefContext) -> None:
     _add_globals(ctx)
     assert isinstance(ctx.reason, nodes.MemberExpr)
@@ -217,10 +222,10 @@ def _cls_decorator_hook(ctx: ClassDefContext) -> None:
 
     assert (
         isinstance(node_type, Instance)
-        and names._type_id_for_named_node(node_type.type) is names.REGISTRY
+        and names.type_id_for_named_node(node_type.type) is names.REGISTRY
     )
 
-    decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+    decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
 
 
 def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
@@ -228,69 +233,52 @@ def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
 
     cls = ctx.cls
 
-    _make_declarative_meta(ctx.api, cls)
+    _set_declarative_metaclass(ctx.api, cls)
 
-    decl_class._scan_declarative_assignments_and_apply_types(
+    util.set_is_base(ctx.cls.info)
+    decl_class.scan_declarative_assignments_and_apply_types(
         cls, ctx.api, is_mixin_scan=True
     )
 
 
-def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
-    """Generate a declarative Base class when the declarative_base() function
-    is encountered."""
-
+def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
     _add_globals(ctx)
+    util.set_is_base(ctx.cls.info)
+    decl_class.scan_declarative_assignments_and_apply_types(
+        ctx.cls, ctx.api, is_mixin_scan=True
+    )
 
-    cls = ClassDef(ctx.name, Block([]))
-    cls.fullname = ctx.api.qualified_name(ctx.name)
-
-    info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
-    cls.info = info
-    _make_declarative_meta(ctx.api, cls)
-
-    cls_arg = util._get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
-    if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
-        decl_class._scan_declarative_assignments_and_apply_types(
-            cls_arg.node.defn, ctx.api, is_mixin_scan=True
-        )
-        info.bases = [Instance(cls_arg.node, [])]
-    else:
-        obj = ctx.api.named_type("__builtins__.object")
-
-        info.bases = [obj]
 
-    try:
-        calculate_mro(info)
-    except MroError:
-        util.fail(
-            ctx.api, "Not able to calculate MRO for declarative base", ctx.call
-        )
-        obj = ctx.api.named_type("__builtins__.object")
-        info.bases = [obj]
-        info.fallback_to_any = True
+def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
+    util.set_is_base(ctx.cls.info)
 
-    ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
 
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+    _add_globals(ctx)
+    decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
 
-def _make_declarative_meta(
-    api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
-) -> None:
 
-    declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta")
-    declarative_meta_name.kind = GDEF
-    declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta"
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+    # how do I....tell it it has no attribute of a certain name?
+    # can't find any Type that seems to match that
+    return ctx.default_attr_type
 
-    # installed by _add_globals
-    sym = api.lookup_qualified("__sa_DeclarativeMeta", target_cls)
 
-    assert sym is not None and isinstance(sym.node, nodes.TypeInfo)
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
+    """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
+    for all class defs
 
-    declarative_meta_typeinfo = sym.node
-    declarative_meta_name.node = declarative_meta_typeinfo
+    """
 
-    target_cls.metaclass = declarative_meta_name
+    util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
 
-    declarative_meta_instance = Instance(declarative_meta_typeinfo, [])
 
+def _set_declarative_metaclass(
+    api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
+) -> None:
     info = target_cls.info
-    info.declared_metaclass = info.metaclass_type = declarative_meta_instance
+    sym = api.lookup_fully_qualified_or_none(
+        "sqlalchemy.orm.decl_api.DeclarativeMeta"
+    )
+    assert sym is not None and isinstance(sym.node, TypeInfo)
+    info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
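
To illustrate the cls= branch of _dynamic_class_hook() above (a sketch with assumed
names, not part of the commit): a plain class passed as declarative_base(cls=...)
is marked as a base, scanned as a mixin, and becomes the base of the generated
declarative Base:

    from sqlalchemy import Column, Integer
    from sqlalchemy.orm import declarative_base


    class CustomBase:
        # scanned via scan_declarative_assignments_and_apply_types(...,
        # is_mixin_scan=True) and installed as the sole base of Base
        id = Column(Integer, primary_key=True)


    Base = declarative_base(cls=CustomBase)


    class User(Base):
        __tablename__ = "user"
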
index 67c3fa209162e997b724ad0ecf30e30eef73506d..614805d77c139fbea2ca116c5f4c207ffb24f1d3 100644 (file)
--- a/lib/sqlalchemy/ext/mypy/util.py
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -1,5 +1,4 @@
 from typing import Any
-from typing import cast
 from typing import Iterable
 from typing import Iterator
 from typing import List
@@ -10,12 +9,15 @@ from typing import Type as TypingType
 from typing import TypeVar
 from typing import Union
 
+from mypy.nodes import ARG_POS
 from mypy.nodes import CallExpr
 from mypy.nodes import ClassDef
 from mypy.nodes import CLASSDEF_NO_INFO
 from mypy.nodes import Context
+from mypy.nodes import Expression
 from mypy.nodes import IfStmt
 from mypy.nodes import JsonDict
+from mypy.nodes import MemberExpr
 from mypy.nodes import NameExpr
 from mypy.nodes import Statement
 from mypy.nodes import SymbolTableNode
@@ -24,10 +26,11 @@ from mypy.plugin import ClassDefContext
 from mypy.plugin import DynamicClassDefContext
 from mypy.plugin import SemanticAnalyzerPluginInterface
 from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.typeops import map_type_from_supertype
 from mypy.types import Instance
 from mypy.types import NoneType
-from mypy.types import ProperType
 from mypy.types import Type
+from mypy.types import TypeVarType
 from mypy.types import UnboundType
 from mypy.types import UnionType
 
@@ -35,53 +38,117 @@ from mypy.types import UnionType
 _TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
 
 
-class DeclClassApplied:
+class SQLAlchemyAttribute:
     def __init__(
         self,
-        is_mapped: bool,
-        has_table: bool,
-        mapped_attr_names: Iterable[Tuple[str, ProperType]],
-        mapped_mro: Iterable[Instance],
-    ):
-        self.is_mapped = is_mapped
-        self.has_table = has_table
-        self.mapped_attr_names = list(mapped_attr_names)
-        self.mapped_mro = list(mapped_mro)
+        name: str,
+        line: int,
+        column: int,
+        typ: Optional[Type],
+        info: TypeInfo,
+    ) -> None:
+        self.name = name
+        self.line = line
+        self.column = column
+        self.type = typ
+        self.info = info
 
     def serialize(self) -> JsonDict:
+        assert self.type
         return {
-            "is_mapped": self.is_mapped,
-            "has_table": self.has_table,
-            "mapped_attr_names": [
-                (name, type_.serialize())
-                for name, type_ in self.mapped_attr_names
-            ],
-            "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
+            "name": self.name,
+            "line": self.line,
+            "column": self.column,
+            "type": self.type.serialize(),
         }
 
+    def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
+        """Expands type vars in the context of a subtype when an attribute is inherited
+        from a generic super type."""
+        if not isinstance(self.type, TypeVarType):
+            return
+
+        self.type = map_type_from_supertype(self.type, sub_type, self.info)
+
     @classmethod
     def deserialize(
-        cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
-    ) -> "DeclClassApplied":
-
-        return DeclClassApplied(
-            is_mapped=data["is_mapped"],
-            has_table=data["has_table"],
-            mapped_attr_names=cast(
-                List[Tuple[str, ProperType]],
-                [
-                    (name, deserialize_and_fixup_type(type_, api))
-                    for name, type_ in data["mapped_attr_names"]
-                ],
-            ),
-            mapped_mro=cast(
-                List[Instance],
-                [
-                    deserialize_and_fixup_type(type_, api)
-                    for type_ in data["mapped_mro"]
-                ],
-            ),
-        )
+        cls,
+        info: TypeInfo,
+        data: JsonDict,
+        api: SemanticAnalyzerPluginInterface,
+    ) -> "SQLAlchemyAttribute":
+        data = data.copy()
+        typ = deserialize_and_fixup_type(data.pop("type"), api)
+        return cls(typ=typ, info=info, **data)
+
+
+def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
+    info.metadata.setdefault("sqlalchemy", {})[key] = data
+
+
+def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+    return info.metadata.get("sqlalchemy", {}).get(key, None)
+
+
+def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+    if info.mro:
+        for base in info.mro:
+            metadata = _get_info_metadata(base, key)
+            if metadata is not None:
+                return metadata
+    return None
+
+
+def set_is_base(info: TypeInfo) -> None:
+    _set_info_metadata(info, "is_base", True)
+
+
+def get_is_base(info: TypeInfo) -> bool:
+    is_base = _get_info_metadata(info, "is_base")
+    return is_base is True
+
+
+def has_declarative_base(info: TypeInfo) -> bool:
+    is_base = _get_info_mro_metadata(info, "is_base")
+    return is_base is True
+
+
+def set_has_table(info: TypeInfo) -> None:
+    _set_info_metadata(info, "has_table", True)
+
+
+def get_has_table(info: TypeInfo) -> bool:
+    is_base = _get_info_metadata(info, "has_table")
+    return is_base is True
+
+
+def get_mapped_attributes(
+    info: TypeInfo, api: SemanticAnalyzerPluginInterface
+) -> Optional[List[SQLAlchemyAttribute]]:
+    mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
+        info, "mapped_attributes"
+    )
+    if mapped_attributes is None:
+        return None
+
+    attributes: List[SQLAlchemyAttribute] = []
+
+    for data in mapped_attributes:
+        attr = SQLAlchemyAttribute.deserialize(info, data, api)
+        attr.expand_typevar_from_subtype(info)
+        attributes.append(attr)
+
+    return attributes
+
+
+def set_mapped_attributes(
+    info: TypeInfo, attributes: List[SQLAlchemyAttribute]
+) -> None:
+    _set_info_metadata(
+        info,
+        "mapped_attributes",
+        [attribute.serialize() for attribute in attributes],
+    )
 
 
 def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
@@ -106,14 +173,14 @@ def add_global(
 
 
 @overload
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
     callexpr: CallExpr, name: str, *, expr_types: None = ...
 ) -> Optional[Union[CallExpr, NameExpr]]:
     ...
 
 
 @overload
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
     callexpr: CallExpr,
     name: str,
     *,
@@ -122,7 +189,7 @@ def _get_callexpr_kwarg(
     ...
 
 
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
     callexpr: CallExpr,
     name: str,
     *,
@@ -142,7 +209,7 @@ def _get_callexpr_kwarg(
     return None
 
 
-def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
     for stmt in stmts:
         if (
             isinstance(stmt, IfStmt)
@@ -155,7 +222,7 @@ def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
             yield stmt
 
 
-def _unbound_to_instance(
+def unbound_to_instance(
     api: SemanticAnalyzerPluginInterface, typ: Type
 ) -> Type:
     """Take the UnboundType that we seem to get as the ret_type from a FuncDef
@@ -173,10 +240,10 @@ def _unbound_to_instance(
     if typ.name == "Optional":
         # convert from "Optional?" to the more familiar
         # UnionType[..., NoneType()]
-        return _unbound_to_instance(
+        return unbound_to_instance(
             api,
             UnionType(
-                [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+                [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
                 + [NoneType()]
             ),
         )
@@ -193,7 +260,7 @@ def _unbound_to_instance(
         return Instance(
             bound_type,
             [
-                _unbound_to_instance(api, arg)
+                unbound_to_instance(api, arg)
                 if isinstance(arg, UnboundType)
                 else arg
                 for arg in typ.args
@@ -203,9 +270,9 @@ def _unbound_to_instance(
         return typ
 
 
-def _info_for_cls(
+def info_for_cls(
     cls: ClassDef, api: SemanticAnalyzerPluginInterface
-) -> TypeInfo:
+) -> Optional[TypeInfo]:
     if cls.info is CLASSDEF_NO_INFO:
         sym = api.lookup_qualified(cls.name, cls)
         if sym is None:
@@ -214,3 +281,15 @@ def _info_for_cls(
         return sym.node
 
     return cls.info
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+    column_descriptor = NameExpr("__sa_Mapped")
+    column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+    member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+    return CallExpr(
+        member_expr,
+        [expr],
+        [ARG_POS],
+        ["arg1"],
+    )
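
To make the new storage format concrete, a sketch of what the helpers above write
into TypeInfo.metadata; the shape follows set_is_base(), set_has_table(),
set_mapped_attributes() and SQLAlchemyAttribute.serialize(), while the literal
values are assumptions:

    # info.metadata["sqlalchemy"] for a mapped class that declared a
    # __tablename__ and one scanned attribute; declarative bases and mixins
    # carry "is_base": True instead
    sqlalchemy_metadata = {
        "has_table": True,
        "mapped_attributes": [
            {
                "name": "id",
                "line": 12,
                "column": 4,
                # serialized mypy Type, restored via deserialize_and_fixup_type()
                "type": "<mypy-serialized type JSON>",
            },
        ],
    }
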
index 846b7284891188d703d4985fa17edb9e3a9fa71b..4b238fcadf182b30167be03aa0577500a4dae604 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -46,7 +46,7 @@ install_requires =
 asyncio =
     greenlet!=0.4.17;python_version>="3"
 mypy =
-    mypy >= 0.800;python_version>="3"
+    mypy >= 0.910;python_version>="3"
     sqlalchemy2-stubs
 mssql = pyodbc
 mssql_pymssql = pymssql
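
Usage note, not part of the commit: the mypy extra above installs the plugin's
dependencies (now requiring mypy 0.910 or newer), and the plugin itself is enabled
through the standard mypy plugin setting, for example:

    # in the calling project's setup.cfg or mypy.ini
    [mypy]
    plugins = sqlalchemy.ext.mypy.plugin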