From 3e2295b2da7b57a6669f26db0df78f6409934184 Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Wed, 14 Jul 2021 13:48:30 -0500 Subject: [PATCH] Refactor mypy plugin Change-Id: I067d56dcfbc998ddd1b22a448f756859428b9e31 --- lib/sqlalchemy/ext/mypy/apply.py | 115 +++++++++------ lib/sqlalchemy/ext/mypy/decl_class.py | 150 ++++++++++--------- lib/sqlalchemy/ext/mypy/infer.py | 32 ++-- lib/sqlalchemy/ext/mypy/names.py | 44 +++--- lib/sqlalchemy/ext/mypy/plugin.py | 202 ++++++++++++-------------- lib/sqlalchemy/ext/mypy/util.py | 179 ++++++++++++++++------- setup.cfg | 2 +- 7 files changed, 411 insertions(+), 313 deletions(-) diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 293ef2f9a5..cf5b4fda25 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -5,10 +5,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import List from typing import Optional from typing import Union -from mypy import nodes from mypy.nodes import ARG_NAMED_OPT from mypy.nodes import Argument from mypy.nodes import AssignmentStmt @@ -17,6 +17,7 @@ from mypy.nodes import ClassDef from mypy.nodes import MDEF from mypy.nodes import MemberExpr from mypy.nodes import NameExpr +from mypy.nodes import RefExpr from mypy.nodes import StrExpr from mypy.nodes import SymbolTableNode from mypy.nodes import TempNode @@ -37,18 +38,18 @@ from . import infer from . import util -def _apply_mypy_mapped_attr( +def apply_mypy_mapped_attr( cls: ClassDef, api: SemanticAnalyzerPluginInterface, item: Union[NameExpr, StrExpr], - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: if isinstance(item, NameExpr): name = item.name elif isinstance(item, StrExpr): name = item.value else: - return + return None for stmt in cls.defs.body: if ( @@ -59,7 +60,7 @@ def _apply_mypy_mapped_attr( break else: util.fail(api, "Can't find mapped attribute {}".format(name), cls) - return + return None if stmt.type is None: util.fail( @@ -68,32 +69,38 @@ def _apply_mypy_mapped_attr( "typing information", stmt, ) - return + return None left_hand_explicit_type = get_proper_type(stmt.type) assert isinstance( left_hand_explicit_type, (Instance, UnionType, UnboundType) ) - cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=item.line, + column=item.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) - _apply_type_to_mapped_statement( + apply_type_to_mapped_statement( api, stmt, stmt.lvalues[0], left_hand_explicit_type, None ) -def _re_apply_declarative_assignments( +def re_apply_declarative_assignments( cls: ClassDef, api: SemanticAnalyzerPluginInterface, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """For multiple class passes, re-apply our left-hand side types as mypy seems to reset them in place. """ - mapped_attr_lookup = { - name: typ for name, typ in cls_metadata.mapped_attr_names - } + mapped_attr_lookup = {attr.name: attr for attr in attributes} update_cls_metadata = False for stmt in cls.defs.body: @@ -109,28 +116,37 @@ def _re_apply_declarative_assignments( ): left_node = stmt.lvalues[0].node - python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name] + python_type_for_type = mapped_attr_lookup[ + stmt.lvalues[0].name + ].type + + left_node_proper_type = get_proper_type(left_node.type) + # if we have scanned an UnboundType and now there's a more # specific type than UnboundType, call the re-scan so we # can get that set up correctly if ( isinstance(python_type_for_type, UnboundType) - and not isinstance(left_node.type, UnboundType) + and not isinstance(left_node_proper_type, UnboundType) and ( - isinstance(stmt.rvalue.callee, MemberExpr) + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, MemberExpr) + and isinstance(stmt.rvalue.callee.expr, NameExpr) + and stmt.rvalue.callee.expr.node is not None and stmt.rvalue.callee.expr.node.fullname == "sqlalchemy.orm.attributes.Mapped" and stmt.rvalue.callee.name == "_empty_constructor" and isinstance(stmt.rvalue.args[0], CallExpr) + and isinstance(stmt.rvalue.args[0].callee, RefExpr) ) ): python_type_for_type = ( - infer._infer_type_from_right_hand_nameexpr( + infer.infer_type_from_right_hand_nameexpr( api, stmt, left_node, - left_node.type, + left_node_proper_type, stmt.rvalue.args[0].callee, ) ) @@ -140,21 +156,23 @@ def _re_apply_declarative_assignments( ): continue - # update the DeclClassApplied with the better information - mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type + # update the SQLAlchemyAttribute with the better information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type + update_cls_metadata = True - left_node.type = api.named_type( - "__sa_Mapped", [python_type_for_type] - ) + if python_type_for_type is not None: + left_node.type = api.named_type( + "__sa_Mapped", [python_type_for_type] + ) if update_cls_metadata: - cls_metadata.mapped_attr_names[:] = [ - (k, v) for k, v in mapped_attr_lookup.items() - ] + util.set_mapped_attributes(cls.info, attributes) -def _apply_type_to_mapped_statement( +def apply_type_to_mapped_statement( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, lvalue: NameExpr, @@ -205,30 +223,36 @@ def _apply_type_to_mapped_statement( # _sa_Mapped._empty_constructor() # the original right-hand side is maintained so it gets type checked # internally - column_descriptor = nodes.NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") - orig_call_expr = stmt.rvalue - stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"]) + stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue) -def _add_additional_orm_attributes( +def add_additional_orm_attributes( cls: ClassDef, api: SemanticAnalyzerPluginInterface, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """Apply __init__, __table__ and other attributes to the mapped class.""" - info = util._info_for_cls(cls, api) - if "__init__" not in info.names and cls_metadata.is_mapped: - mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names} + info = util.info_for_cls(cls, api) - for mapped_base in cls_metadata.mapped_mro: - base_cls_metadata = util.DeclClassApplied.deserialize( - mapped_base.type.metadata["_sa_decl_class_applied"], api - ) - for n, t in base_cls_metadata.mapped_attr_names: - mapped_attr_names.setdefault(n, t) + if info is None: + return + + is_base = util.get_is_base(info) + + if "__init__" not in info.names and not is_base: + mapped_attr_names = {attr.name: attr.type for attr in attributes} + + for base in info.mro[1:-1]: + if "sqlalchemy" not in info.metadata: + continue + + base_cls_attributes = util.get_mapped_attributes(base, api) + if base_cls_attributes is None: + continue + + for attr in base_cls_attributes: + mapped_attr_names.setdefault(attr.name, attr.type) arguments = [] for name, typ in mapped_attr_names.items(): @@ -242,13 +266,14 @@ def _add_additional_orm_attributes( kind=ARG_NAMED_OPT, ) ) + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) - if "__table__" not in info.names and cls_metadata.has_table: + if "__table__" not in info.names and util.get_has_table(info): _apply_placeholder_attr_to_class( api, cls, "sqlalchemy.sql.schema.Table", "__table__" ) - if cls_metadata.is_mapped: + if not is_base: _apply_placeholder_attr_to_class( api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" ) diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 45d025fc99..23c78aa51f 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -5,14 +5,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import List from typing import Optional from typing import Union -from mypy import nodes from mypy.nodes import AssignmentStmt from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import Decorator +from mypy.nodes import LambdaExpr from mypy.nodes import ListExpr from mypy.nodes import MemberExpr from mypy.nodes import NameExpr @@ -42,62 +43,68 @@ from . import names from . import util -def _scan_declarative_assignments_and_apply_types( +def scan_declarative_assignments_and_apply_types( cls: ClassDef, api: SemanticAnalyzerPluginInterface, is_mixin_scan: bool = False, -) -> Optional[util.DeclClassApplied]: +) -> Optional[List[util.SQLAlchemyAttribute]]: - info = util._info_for_cls(cls, api) + info = util.info_for_cls(cls, api) if info is None: # this can occur during cached passes return None elif cls.fullname.startswith("builtins"): return None - elif "_sa_decl_class_applied" in info.metadata: - cls_metadata = util.DeclClassApplied.deserialize( - info.metadata["_sa_decl_class_applied"], api - ) + mapped_attributes: Optional[ + List[util.SQLAlchemyAttribute] + ] = util.get_mapped_attributes(info, api) + + if mapped_attributes is not None: # ensure that a class that's mapped is always picked up by # its mapped() decorator or declarative metaclass before # it would be detected as an unmapped mixin class - if not is_mixin_scan: - assert cls_metadata.is_mapped + if not is_mixin_scan: # mypy can call us more than once. it then *may* have reset the # left hand side of everything, but not the right that we removed, # removing our ability to re-scan. but we have the types # here, so lets re-apply them, or if we have an UnboundType, # we can re-scan - apply._re_apply_declarative_assignments(cls, api, cls_metadata) + apply.re_apply_declarative_assignments(cls, api, mapped_attributes) - return cls_metadata + return mapped_attributes - cls_metadata = util.DeclClassApplied(not is_mixin_scan, False, [], []) + mapped_attributes = [] if not cls.defs.body: # when we get a mixin class from another file, the body is # empty (!) but the names are in the symbol table. so use that. for sym_name, sym in info.names.items(): - _scan_symbol_table_entry(cls, api, sym_name, sym, cls_metadata) + _scan_symbol_table_entry( + cls, api, sym_name, sym, mapped_attributes + ) else: - for stmt in util._flatten_typechecking(cls.defs.body): + for stmt in util.flatten_typechecking(cls.defs.body): if isinstance(stmt, AssignmentStmt): - _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata) + _scan_declarative_assignment_stmt( + cls, api, stmt, mapped_attributes + ) elif isinstance(stmt, Decorator): - _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata) - _scan_for_mapped_bases(cls, api, cls_metadata) + _scan_declarative_decorator_stmt( + cls, api, stmt, mapped_attributes + ) + _scan_for_mapped_bases(cls, api) if not is_mixin_scan: - apply._add_additional_orm_attributes(cls, api, cls_metadata) + apply.add_additional_orm_attributes(cls, api, mapped_attributes) - info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize() + util.set_mapped_attributes(info, mapped_attributes) - return cls_metadata + return mapped_attributes def _scan_symbol_table_entry( @@ -105,7 +112,7 @@ def _scan_symbol_table_entry( api: SemanticAnalyzerPluginInterface, name: str, value: SymbolTableNode, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """Extract mapping information from a SymbolTableNode that's in the type.names dictionary. @@ -116,7 +123,7 @@ def _scan_symbol_table_entry( return left_hand_explicit_type = None - type_id = names._type_id_for_named_node(value_type.type) + type_id = names.type_id_for_named_node(value_type.type) # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) err = False @@ -148,11 +155,11 @@ def _scan_symbol_table_entry( if isinstance(typeengine_arg, (UnboundType, TypeInfo)): sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) if sym is not None and isinstance(sym.node, TypeInfo): - if names._has_base_type_id(sym.node, names.TYPEENGINE): + if names.has_base_type_id(sym.node, names.TYPEENGINE): left_hand_explicit_type = UnionType( [ - infer._extract_python_type_from_typeengine( + infer.extract_python_type_from_typeengine( api, sym.node, [] ), NoneType(), @@ -178,14 +185,23 @@ def _scan_symbol_table_entry( left_hand_explicit_type = AnyType(TypeOfAny.special_form) if left_hand_explicit_type is not None: - cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type)) + assert value.node is not None + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=value.node.line, + column=value.node.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) def _scan_declarative_decorator_stmt( cls: ClassDef, api: SemanticAnalyzerPluginInterface, stmt: Decorator, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """Extract mapping information from a @declared_attr in a declarative class. @@ -212,7 +228,7 @@ def _scan_declarative_decorator_stmt( for dec in stmt.decorators: if ( isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) - and names._type_id_for_named_node(dec) is names.DECLARED_ATTR + and names.type_id_for_named_node(dec) is names.DECLARED_ATTR ): break else: @@ -225,7 +241,7 @@ def _scan_declarative_decorator_stmt( if isinstance(stmt.func.type, CallableType): func_type = stmt.func.type.ret_type if isinstance(func_type, UnboundType): - type_id = names._type_id_for_unbound_type(func_type, cls, api) + type_id = names.type_id_for_unbound_type(func_type, cls, api) else: # this does not seem to occur unless the type argument is # incorrect @@ -249,10 +265,10 @@ def _scan_declarative_decorator_stmt( if isinstance(typeengine_arg, UnboundType): sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) if sym is not None and isinstance(sym.node, TypeInfo): - if names._has_base_type_id(sym.node, names.TYPEENGINE): + if names.has_base_type_id(sym.node, names.TYPEENGINE): left_hand_explicit_type = UnionType( [ - infer._extract_python_type_from_typeengine( + infer.extract_python_type_from_typeengine( api, sym.node, [] ), NoneType(), @@ -291,7 +307,7 @@ def _scan_declarative_decorator_stmt( # we see everywhere else. if isinstance(left_hand_explicit_type, UnboundType): left_hand_explicit_type = get_proper_type( - util._unbound_to_instance(api, left_hand_explicit_type) + util.unbound_to_instance(api, left_hand_explicit_type) ) left_node.node.type = api.named_type( @@ -305,23 +321,21 @@ def _scan_declarative_decorator_stmt( # : Mapped[] = # _sa_Mapped._empty_constructor(lambda: ) # the function body is maintained so it gets type checked internally - column_descriptor = nodes.NameExpr("__sa_Mapped") - column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" - mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") - - arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body) - rvalue = CallExpr( - mm, - [arg], - [nodes.ARG_POS], - ["arg1"], + rvalue = util.expr_to_mapped_constructor( + LambdaExpr(stmt.func.arguments, stmt.func.body) ) new_stmt = AssignmentStmt([left_node], rvalue) new_stmt.type = left_node.node.type - cls_metadata.mapped_attr_names.append( - (left_node.name, left_hand_explicit_type) + attributes.append( + util.SQLAlchemyAttribute( + name=left_node.name, + line=stmt.line, + column=stmt.column, + typ=left_hand_explicit_type, + info=cls.info, + ) ) cls.defs.body[dec_index] = new_stmt @@ -330,7 +344,7 @@ def _scan_declarative_assignment_stmt( cls: ClassDef, api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, - cls_metadata: util.DeclClassApplied, + attributes: List[util.SQLAlchemyAttribute], ) -> None: """Extract mapping information from an assignment statement in a declarative class. @@ -356,10 +370,10 @@ def _scan_declarative_assignment_stmt( if node.name == "__abstract__": if api.parse_bool(stmt.rvalue) is True: - cls_metadata.is_mapped = False + util.set_is_base(cls.info) return elif node.name == "__tablename__": - cls_metadata.has_table = True + util.set_has_table(cls.info) elif node.name.startswith("__"): return elif node.name == "_mypy_mapped_attrs": @@ -368,7 +382,7 @@ def _scan_declarative_assignment_stmt( else: for item in stmt.rvalue.items: if isinstance(item, (NameExpr, StrExpr)): - apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata) + apply.apply_mypy_mapped_attr(cls, api, item, attributes) left_hand_mapped_type: Optional[Type] = None left_hand_explicit_type: Optional[ProperType] = None @@ -388,7 +402,7 @@ def _scan_declarative_assignment_stmt( if ( mapped_sym is not None and mapped_sym.node is not None - and names._type_id_for_named_node(mapped_sym.node) + and names.type_id_for_named_node(mapped_sym.node) is names.MAPPED ): left_hand_explicit_type = get_proper_type( @@ -404,7 +418,7 @@ def _scan_declarative_assignment_stmt( node_type = get_proper_type(node.type) if ( isinstance(node_type, Instance) - and names._type_id_for_named_node(node_type.type) is names.MAPPED + and names.type_id_for_named_node(node_type.type) is names.MAPPED ): # print(node.type) # sqlalchemy.orm.attributes.Mapped[] @@ -426,7 +440,7 @@ def _scan_declarative_assignment_stmt( stmt.rvalue.callee, RefExpr ): - python_type_for_type = infer._infer_type_from_right_hand_nameexpr( + python_type_for_type = infer.infer_type_from_right_hand_nameexpr( api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee ) @@ -438,9 +452,17 @@ def _scan_declarative_assignment_stmt( assert python_type_for_type is not None - cls_metadata.mapped_attr_names.append((node.name, python_type_for_type)) + attributes.append( + util.SQLAlchemyAttribute( + name=node.name, + line=stmt.line, + column=stmt.column, + typ=python_type_for_type, + info=cls.info, + ) + ) - apply._apply_type_to_mapped_statement( + apply.apply_type_to_mapped_statement( api, stmt, lvalue, @@ -452,7 +474,6 @@ def _scan_declarative_assignment_stmt( def _scan_for_mapped_bases( cls: ClassDef, api: SemanticAnalyzerPluginInterface, - cls_metadata: util.DeclClassApplied, ) -> None: """Given a class, iterate through its superclass hierarchy to find all other classes that are considered as ORM-significant. @@ -462,25 +483,18 @@ def _scan_for_mapped_bases( """ - info = util._info_for_cls(cls, api) + info = util.info_for_cls(cls, api) - baseclasses = list(info.bases) - - while baseclasses: - base: Instance = baseclasses.pop(0) + if info is None: + return - if base.type.fullname.startswith("builtins"): + for base_info in info.mro[1:-1]: + if base_info.fullname.startswith("builtins"): continue # scan each base for mapped attributes. if they are not already # scanned (but have all their type info), that means they are unmapped # mixins - base_decl_class_applied = ( - _scan_declarative_assignments_and_apply_types( - base.type.defn, api, is_mixin_scan=True - ) + scan_declarative_assignments_and_apply_types( + base_info.defn, api, is_mixin_scan=True ) - - if base_decl_class_applied is not None: - cls_metadata.mapped_mro.append(base) - baseclasses.extend(base.type.bases) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py index ca2b62966e..85a94bba61 100644 --- a/lib/sqlalchemy/ext/mypy/infer.py +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -35,15 +35,15 @@ from . import names from . import util -def _infer_type_from_right_hand_nameexpr( +def infer_type_from_right_hand_nameexpr( api: SemanticAnalyzerPluginInterface, stmt: AssignmentStmt, node: Var, left_hand_explicit_type: Optional[ProperType], - infer_from_right_side: NameExpr, + infer_from_right_side: RefExpr, ) -> Optional[ProperType]: - type_id = names._type_id_for_callee(infer_from_right_side) + type_id = names.type_id_for_callee(infer_from_right_side) if type_id is None: return None @@ -60,7 +60,7 @@ def _infer_type_from_right_hand_nameexpr( api, stmt, node, left_hand_explicit_type ) elif type_id is names.SYNONYM_PROPERTY: - python_type_for_type = _infer_type_from_left_hand_type_only( + python_type_for_type = infer_type_from_left_hand_type_only( api, node, left_hand_explicit_type ) elif type_id is names.COMPOSITE_PROPERTY: @@ -128,8 +128,8 @@ def _infer_type_from_relationship( # string expression # isinstance(target_cls_arg, StrExpr) - uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist") - collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg( + uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist") + collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg( stmt.rvalue, "collection_class" ) type_is_a_collection = False @@ -137,7 +137,7 @@ def _infer_type_from_relationship( # this can be used to determine Optional for a many-to-one # in the same way nullable=False could be used, if we start supporting # that. - # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin") + # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin") if ( uselist_arg is not None @@ -218,7 +218,7 @@ def _infer_type_from_relationship( util.fail(api, msg.format(node.name), node) if python_type_for_type is None: - return _infer_type_from_left_hand_type_only( + return infer_type_from_left_hand_type_only( api, node, left_hand_explicit_type ) elif left_hand_explicit_type is not None: @@ -260,7 +260,7 @@ def _infer_type_from_decl_composite_property( python_type_for_type = None if python_type_for_type is None: - return _infer_type_from_left_hand_type_only( + return infer_type_from_left_hand_type_only( api, node, left_hand_explicit_type ) elif left_hand_explicit_type is not None: @@ -287,7 +287,7 @@ def _infer_type_from_decl_column_property( first_prop_arg = stmt.rvalue.args[0] if isinstance(first_prop_arg, CallExpr): - type_id = names._type_id_for_callee(first_prop_arg.callee) + type_id = names.type_id_for_callee(first_prop_arg.callee) # look for column_property() / deferred() etc with Column as first # argument @@ -300,7 +300,7 @@ def _infer_type_from_decl_column_property( right_hand_expression=first_prop_arg, ) - return _infer_type_from_left_hand_type_only( + return infer_type_from_left_hand_type_only( api, node, left_hand_explicit_type ) @@ -378,10 +378,10 @@ def _infer_type_from_decl_column( if callee is None: return None - if isinstance(callee.node, TypeInfo) and names._mro_has_id( + if isinstance(callee.node, TypeInfo) and names.mro_has_id( callee.node.mro, names.TYPEENGINE ): - python_type_for_type = _extract_python_type_from_typeengine( + python_type_for_type = extract_python_type_from_typeengine( api, callee.node, type_args ) @@ -396,7 +396,7 @@ def _infer_type_from_decl_column( else: # it's not TypeEngine, it's typically implicitly typed # like ForeignKey. we can't infer from the right side. - return _infer_type_from_left_hand_type_only( + return infer_type_from_left_hand_type_only( api, node, left_hand_explicit_type ) @@ -472,7 +472,7 @@ def _infer_collection_type_from_left_and_inferred_right( ) -def _infer_type_from_left_hand_type_only( +def infer_type_from_left_hand_type_only( api: SemanticAnalyzerPluginInterface, node: Var, left_hand_explicit_type: Optional[ProperType], @@ -499,7 +499,7 @@ def _infer_type_from_left_hand_type_only( return left_hand_explicit_type -def _extract_python_type_from_typeengine( +def extract_python_type_from_typeengine( api: SemanticAnalyzerPluginInterface, node: TypeInfo, type_args: Sequence[Expression], diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py index 653ce4985a..22a79e29b9 100644 --- a/lib/sqlalchemy/ext/mypy/names.py +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -153,7 +153,7 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { } -def _has_base_type_id(info: TypeInfo, type_id: int) -> bool: +def has_base_type_id(info: TypeInfo, type_id: int) -> bool: for mr in info.mro: check_type_id, fullnames = _lookup.get(mr.name, (None, None)) if check_type_id == type_id: @@ -167,7 +167,7 @@ def _has_base_type_id(info: TypeInfo, type_id: int) -> bool: return mr.fullname in fullnames -def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: +def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: for mr in mro: check_type_id, fullnames = _lookup.get(mr.name, (None, None)) if check_type_id == type_id: @@ -181,49 +181,41 @@ def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: return mr.fullname in fullnames -def _type_id_for_unbound_type( +def type_id_for_unbound_type( type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface ) -> Optional[int]: - type_id = None - sym = api.lookup_qualified(type_.name, type_) if sym is not None: if isinstance(sym.node, TypeAlias): target_type = get_proper_type(sym.node.target) if isinstance(target_type, Instance): - type_id = _type_id_for_named_node(target_type.type) + return type_id_for_named_node(target_type.type) elif isinstance(sym.node, TypeInfo): - type_id = _type_id_for_named_node(sym.node) + return type_id_for_named_node(sym.node) - return type_id + return None -def _type_id_for_callee(callee: Expression) -> Optional[int]: +def type_id_for_callee(callee: Expression) -> Optional[int]: if isinstance(callee, (MemberExpr, NameExpr)): if isinstance(callee.node, FuncDef): - return _type_id_for_funcdef(callee.node) + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None elif isinstance(callee.node, TypeAlias): target_type = get_proper_type(callee.node.target) if isinstance(target_type, Instance): - type_id = _type_id_for_fullname(target_type.type.fullname) + return type_id_for_fullname(target_type.type.fullname) elif isinstance(callee.node, TypeInfo): - type_id = _type_id_for_named_node(callee) - else: - type_id = None - return type_id - - -def _type_id_for_funcdef(node: FuncDef) -> Optional[int]: - if node.type and isinstance(node.type, CallableType): - ret_type = get_proper_type(node.type.ret_type) - - if isinstance(ret_type, Instance): - return _type_id_for_fullname(ret_type.type.fullname) - + return type_id_for_named_node(callee) return None -def _type_id_for_named_node( +def type_id_for_named_node( node: Union[NameExpr, MemberExpr, SymbolNode] ) -> Optional[int]: type_id, fullnames = _lookup.get(node.name, (None, None)) @@ -236,7 +228,7 @@ def _type_id_for_named_node( return None -def _type_id_for_fullname(fullname: str) -> Optional[int]: +def type_id_for_fullname(fullname: str) -> Optional[int]: tokens = fullname.split(".") immediate = tokens[-1] diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index 687aeb8513..356b0d9489 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -45,29 +45,14 @@ class SQLAlchemyPlugin(Plugin): def get_dynamic_class_hook( self, fullname: str ) -> Optional[Callable[[DynamicClassDefContext], None]]: - if names._type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: return _dynamic_class_hook return None - def get_base_class_hook( + def get_customize_class_mro_hook( self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - - # kind of a strange relationship between get_metaclass_hook() - # and get_base_class_hook(). the former doesn't fire off for - # subclasses. but then you can just check it here from the "base" - # and get the same effect. - sym = self.lookup_fully_qualified(fullname) - - if ( - sym - and isinstance(sym.node, TypeInfo) - and sym.node.metaclass_type - and names._type_id_for_named_node(sym.node.metaclass_type.type) - is names.DECLARATIVE_META - ): - return _base_cls_hook - return None + return _fill_in_decorators def get_class_decorator_hook( self, fullname: str @@ -76,7 +61,7 @@ class SQLAlchemyPlugin(Plugin): sym = self.lookup_fully_qualified(fullname) if sym is not None and sym.node is not None: - type_id = names._type_id_for_named_node(sym.node) + type_id = names.type_id_for_named_node(sym.node) if type_id is names.MAPPED_DECORATOR: return _cls_decorator_hook elif type_id in ( @@ -89,10 +74,29 @@ class SQLAlchemyPlugin(Plugin): return None - def get_customize_class_mro_hook( + def get_metaclass_hook( self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - return _fill_in_decorators + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META: + # Set any classes that explicitly have metaclass=DeclarativeMeta + # as declarative so the check in `get_base_class_hook()` works + return _metaclass_cls_hook + + return None + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if ( + sym + and isinstance(sym.node, TypeInfo) + and util.has_declarative_base(sym.node) + ): + return _base_cls_hook + + return None def get_attribute_hook( self, fullname: str @@ -101,6 +105,7 @@ class SQLAlchemyPlugin(Plugin): "sqlalchemy.orm.attributes.QueryableAttribute." ): return _queryable_getattr_hook + return None def get_additional_deps( @@ -116,10 +121,43 @@ def plugin(version: str) -> TypingType[SQLAlchemyPlugin]: return SQLAlchemyPlugin -def _queryable_getattr_hook(ctx: AttributeContext) -> Type: - # how do I....tell it it has no attribute of a certain name? - # can't find any Type that seems to match that - return ctx.default_attr_type +def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: + """Generate a declarative Base class when the declarative_base() function + is encountered.""" + + _add_globals(ctx) + + cls = ClassDef(ctx.name, Block([])) + cls.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id) + cls.info = info + _set_declarative_metaclass(ctx.api, cls) + + cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) + if cls_arg is not None and isinstance(cls_arg.node, TypeInfo): + util.set_is_base(cls_arg.node) + decl_class.scan_declarative_assignments_and_apply_types( + cls_arg.node.defn, ctx.api, is_mixin_scan=True + ) + info.bases = [Instance(cls_arg.node, [])] + else: + obj = ctx.api.named_type("__builtins__.object") + + info.bases = [obj] + + try: + calculate_mro(info) + except MroError: + util.fail( + ctx.api, "Not able to calculate MRO for declarative base", ctx.call + ) + obj = ctx.api.named_type("__builtins__.object") + info.bases = [obj] + info.fallback_to_any = True + + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + util.set_is_base(info) def _fill_in_decorators(ctx: ClassDefContext) -> None: @@ -173,39 +211,6 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None: ) -def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: - """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space - for all class defs - - """ - - util.add_global( - ctx, - "sqlalchemy.orm.decl_api", - "DeclarativeMeta", - "__sa_DeclarativeMeta", - ) - - util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped") - - -def _cls_metadata_hook(ctx: ClassDefContext) -> None: - _add_globals(ctx) - decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) - - -def _base_cls_hook(ctx: ClassDefContext) -> None: - _add_globals(ctx) - decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) - - -def _declarative_mixin_hook(ctx: ClassDefContext) -> None: - _add_globals(ctx) - decl_class._scan_declarative_assignments_and_apply_types( - ctx.cls, ctx.api, is_mixin_scan=True - ) - - def _cls_decorator_hook(ctx: ClassDefContext) -> None: _add_globals(ctx) assert isinstance(ctx.reason, nodes.MemberExpr) @@ -217,10 +222,10 @@ def _cls_decorator_hook(ctx: ClassDefContext) -> None: assert ( isinstance(node_type, Instance) - and names._type_id_for_named_node(node_type.type) is names.REGISTRY + and names.type_id_for_named_node(node_type.type) is names.REGISTRY ) - decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: @@ -228,69 +233,52 @@ def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: cls = ctx.cls - _make_declarative_meta(ctx.api, cls) + _set_declarative_metaclass(ctx.api, cls) - decl_class._scan_declarative_assignments_and_apply_types( + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( cls, ctx.api, is_mixin_scan=True ) -def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: - """Generate a declarative Base class when the declarative_base() function - is encountered.""" - +def _declarative_mixin_hook(ctx: ClassDefContext) -> None: _add_globals(ctx) + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + ctx.cls, ctx.api, is_mixin_scan=True + ) - cls = ClassDef(ctx.name, Block([])) - cls.fullname = ctx.api.qualified_name(ctx.name) - - info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id) - cls.info = info - _make_declarative_meta(ctx.api, cls) - - cls_arg = util._get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) - if cls_arg is not None and isinstance(cls_arg.node, TypeInfo): - decl_class._scan_declarative_assignments_and_apply_types( - cls_arg.node.defn, ctx.api, is_mixin_scan=True - ) - info.bases = [Instance(cls_arg.node, [])] - else: - obj = ctx.api.named_type("__builtins__.object") - - info.bases = [obj] - try: - calculate_mro(info) - except MroError: - util.fail( - ctx.api, "Not able to calculate MRO for declarative base", ctx.call - ) - obj = ctx.api.named_type("__builtins__.object") - info.bases = [obj] - info.fallback_to_any = True +def _metaclass_cls_hook(ctx: ClassDefContext) -> None: + util.set_is_base(ctx.cls.info) - ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) +def _base_cls_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) -def _make_declarative_meta( - api: SemanticAnalyzerPluginInterface, target_cls: ClassDef -) -> None: - declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta") - declarative_meta_name.kind = GDEF - declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta" +def _queryable_getattr_hook(ctx: AttributeContext) -> Type: + # how do I....tell it it has no attribute of a certain name? + # can't find any Type that seems to match that + return ctx.default_attr_type - # installed by _add_globals - sym = api.lookup_qualified("__sa_DeclarativeMeta", target_cls) - assert sym is not None and isinstance(sym.node, nodes.TypeInfo) +def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: + """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space + for all class defs - declarative_meta_typeinfo = sym.node - declarative_meta_name.node = declarative_meta_typeinfo + """ - target_cls.metaclass = declarative_meta_name + util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped") - declarative_meta_instance = Instance(declarative_meta_typeinfo, []) +def _set_declarative_metaclass( + api: SemanticAnalyzerPluginInterface, target_cls: ClassDef +) -> None: info = target_cls.info - info.declared_metaclass = info.metaclass_type = declarative_meta_instance + sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.orm.decl_api.DeclarativeMeta" + ) + assert sym is not None and isinstance(sym.node, TypeInfo) + info.declared_metaclass = info.metaclass_type = Instance(sym.node, []) diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 67c3fa2091..614805d77c 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -1,5 +1,4 @@ from typing import Any -from typing import cast from typing import Iterable from typing import Iterator from typing import List @@ -10,12 +9,15 @@ from typing import Type as TypingType from typing import TypeVar from typing import Union +from mypy.nodes import ARG_POS from mypy.nodes import CallExpr from mypy.nodes import ClassDef from mypy.nodes import CLASSDEF_NO_INFO from mypy.nodes import Context +from mypy.nodes import Expression from mypy.nodes import IfStmt from mypy.nodes import JsonDict +from mypy.nodes import MemberExpr from mypy.nodes import NameExpr from mypy.nodes import Statement from mypy.nodes import SymbolTableNode @@ -24,10 +26,11 @@ from mypy.plugin import ClassDefContext from mypy.plugin import DynamicClassDefContext from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import deserialize_and_fixup_type +from mypy.typeops import map_type_from_supertype from mypy.types import Instance from mypy.types import NoneType -from mypy.types import ProperType from mypy.types import Type +from mypy.types import TypeVarType from mypy.types import UnboundType from mypy.types import UnionType @@ -35,53 +38,117 @@ from mypy.types import UnionType _TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) -class DeclClassApplied: +class SQLAlchemyAttribute: def __init__( self, - is_mapped: bool, - has_table: bool, - mapped_attr_names: Iterable[Tuple[str, ProperType]], - mapped_mro: Iterable[Instance], - ): - self.is_mapped = is_mapped - self.has_table = has_table - self.mapped_attr_names = list(mapped_attr_names) - self.mapped_mro = list(mapped_mro) + name: str, + line: int, + column: int, + typ: Optional[Type], + info: TypeInfo, + ) -> None: + self.name = name + self.line = line + self.column = column + self.type = typ + self.info = info def serialize(self) -> JsonDict: + assert self.type return { - "is_mapped": self.is_mapped, - "has_table": self.has_table, - "mapped_attr_names": [ - (name, type_.serialize()) - for name, type_ in self.mapped_attr_names - ], - "mapped_mro": [type_.serialize() for type_ in self.mapped_mro], + "name": self.name, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), } + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is inherited + from a generic super type.""" + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) + @classmethod def deserialize( - cls, data: JsonDict, api: SemanticAnalyzerPluginInterface - ) -> "DeclClassApplied": - - return DeclClassApplied( - is_mapped=data["is_mapped"], - has_table=data["has_table"], - mapped_attr_names=cast( - List[Tuple[str, ProperType]], - [ - (name, deserialize_and_fixup_type(type_, api)) - for name, type_ in data["mapped_attr_names"] - ], - ), - mapped_mro=cast( - List[Instance], - [ - deserialize_and_fixup_type(type_, api) - for type_ in data["mapped_mro"] - ], - ), - ) + cls, + info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface, + ) -> "SQLAlchemyAttribute": + data = data.copy() + typ = deserialize_and_fixup_type(data.pop("type"), api) + return cls(typ=typ, info=info, **data) + + +def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: + info.metadata.setdefault("sqlalchemy", {})[key] = data + + +def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]: + return info.metadata.get("sqlalchemy", {}).get(key, None) + + +def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]: + if info.mro: + for base in info.mro: + metadata = _get_info_metadata(base, key) + if metadata is not None: + return metadata + return None + + +def set_is_base(info: TypeInfo) -> None: + _set_info_metadata(info, "is_base", True) + + +def get_is_base(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "is_base") + return is_base is True + + +def has_declarative_base(info: TypeInfo) -> bool: + is_base = _get_info_mro_metadata(info, "is_base") + return is_base is True + + +def set_has_table(info: TypeInfo) -> None: + _set_info_metadata(info, "has_table", True) + + +def get_has_table(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "has_table") + return is_base is True + + +def get_mapped_attributes( + info: TypeInfo, api: SemanticAnalyzerPluginInterface +) -> Optional[List[SQLAlchemyAttribute]]: + mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata( + info, "mapped_attributes" + ) + if mapped_attributes is None: + return None + + attributes: List[SQLAlchemyAttribute] = [] + + for data in mapped_attributes: + attr = SQLAlchemyAttribute.deserialize(info, data, api) + attr.expand_typevar_from_subtype(info) + attributes.append(attr) + + return attributes + + +def set_mapped_attributes( + info: TypeInfo, attributes: List[SQLAlchemyAttribute] +) -> None: + _set_info_metadata( + info, + "mapped_attributes", + [attribute.serialize() for attribute in attributes], + ) def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: @@ -106,14 +173,14 @@ def add_global( @overload -def _get_callexpr_kwarg( +def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, expr_types: None = ... ) -> Optional[Union[CallExpr, NameExpr]]: ... @overload -def _get_callexpr_kwarg( +def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, @@ -122,7 +189,7 @@ def _get_callexpr_kwarg( ... -def _get_callexpr_kwarg( +def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, @@ -142,7 +209,7 @@ def _get_callexpr_kwarg( return None -def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: +def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: for stmt in stmts: if ( isinstance(stmt, IfStmt) @@ -155,7 +222,7 @@ def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: yield stmt -def _unbound_to_instance( +def unbound_to_instance( api: SemanticAnalyzerPluginInterface, typ: Type ) -> Type: """Take the UnboundType that we seem to get as the ret_type from a FuncDef @@ -173,10 +240,10 @@ def _unbound_to_instance( if typ.name == "Optional": # convert from "Optional?" to the more familiar # UnionType[..., NoneType()] - return _unbound_to_instance( + return unbound_to_instance( api, UnionType( - [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + [unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + [NoneType()] ), ) @@ -193,7 +260,7 @@ def _unbound_to_instance( return Instance( bound_type, [ - _unbound_to_instance(api, arg) + unbound_to_instance(api, arg) if isinstance(arg, UnboundType) else arg for arg in typ.args @@ -203,9 +270,9 @@ def _unbound_to_instance( return typ -def _info_for_cls( +def info_for_cls( cls: ClassDef, api: SemanticAnalyzerPluginInterface -) -> TypeInfo: +) -> Optional[TypeInfo]: if cls.info is CLASSDEF_NO_INFO: sym = api.lookup_qualified(cls.name, cls) if sym is None: @@ -214,3 +281,15 @@ def _info_for_cls( return sym.node return cls.info + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped" + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/setup.cfg b/setup.cfg index 846b728489..4b238fcadf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,7 +46,7 @@ install_requires = asyncio = greenlet!=0.4.17;python_version>="3" mypy = - mypy >= 0.800;python_version>="3" + mypy >= 0.910;python_version>="3" sqlalchemy2-stubs mssql = pyodbc mssql_pymssql = pymssql -- 2.47.2