# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import List
from typing import Optional
from typing import Union
-from mypy import nodes
from mypy.nodes import ARG_NAMED_OPT
from mypy.nodes import Argument
from mypy.nodes import AssignmentStmt
from mypy.nodes import MDEF
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
from mypy.nodes import StrExpr
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from . import util
-def _apply_mypy_mapped_attr(
+def apply_mypy_mapped_attr(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
item: Union[NameExpr, StrExpr],
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
if isinstance(item, NameExpr):
name = item.name
elif isinstance(item, StrExpr):
name = item.value
else:
- return
+ return None
for stmt in cls.defs.body:
if (
break
else:
util.fail(api, "Can't find mapped attribute {}".format(name), cls)
- return
+ return None
if stmt.type is None:
util.fail(
"typing information",
stmt,
)
- return
+ return None
left_hand_explicit_type = get_proper_type(stmt.type)
assert isinstance(
left_hand_explicit_type, (Instance, UnionType, UnboundType)
)
- cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=item.line,
+ column=item.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
- _apply_type_to_mapped_statement(
+ apply_type_to_mapped_statement(
api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
)
-def _re_apply_declarative_assignments(
+def re_apply_declarative_assignments(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""For multiple class passes, re-apply our left-hand side types as mypy
seems to reset them in place.
"""
- mapped_attr_lookup = {
- name: typ for name, typ in cls_metadata.mapped_attr_names
- }
+ mapped_attr_lookup = {attr.name: attr for attr in attributes}
update_cls_metadata = False
for stmt in cls.defs.body:
):
left_node = stmt.lvalues[0].node
- python_type_for_type = mapped_attr_lookup[stmt.lvalues[0].name]
+ python_type_for_type = mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type
+
+ left_node_proper_type = get_proper_type(left_node.type)
+
# if we have scanned an UnboundType and now there's a more
# specific type than UnboundType, call the re-scan so we
# can get that set up correctly
if (
isinstance(python_type_for_type, UnboundType)
- and not isinstance(left_node.type, UnboundType)
+ and not isinstance(left_node_proper_type, UnboundType)
and (
- isinstance(stmt.rvalue.callee, MemberExpr)
+ isinstance(stmt.rvalue, CallExpr)
+ and isinstance(stmt.rvalue.callee, MemberExpr)
+ and isinstance(stmt.rvalue.callee.expr, NameExpr)
+ and stmt.rvalue.callee.expr.node is not None
and stmt.rvalue.callee.expr.node.fullname
== "sqlalchemy.orm.attributes.Mapped"
and stmt.rvalue.callee.name == "_empty_constructor"
and isinstance(stmt.rvalue.args[0], CallExpr)
+ and isinstance(stmt.rvalue.args[0].callee, RefExpr)
)
):
python_type_for_type = (
- infer._infer_type_from_right_hand_nameexpr(
+ infer.infer_type_from_right_hand_nameexpr(
api,
stmt,
left_node,
- left_node.type,
+ left_node_proper_type,
stmt.rvalue.args[0].callee,
)
)
):
continue
- # update the DeclClassApplied with the better information
- mapped_attr_lookup[stmt.lvalues[0].name] = python_type_for_type
+ # update the SQLAlchemyAttribute with the better information
+ mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type = python_type_for_type
+
update_cls_metadata = True
- left_node.type = api.named_type(
- "__sa_Mapped", [python_type_for_type]
- )
+ if python_type_for_type is not None:
+ left_node.type = api.named_type(
+ "__sa_Mapped", [python_type_for_type]
+ )
if update_cls_metadata:
- cls_metadata.mapped_attr_names[:] = [
- (k, v) for k, v in mapped_attr_lookup.items()
- ]
+ util.set_mapped_attributes(cls.info, attributes)
-def _apply_type_to_mapped_statement(
+def apply_type_to_mapped_statement(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
lvalue: NameExpr,
# _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
# the original right-hand side is maintained so it gets type checked
# internally
- column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
- mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
- orig_call_expr = stmt.rvalue
- stmt.rvalue = CallExpr(mm, [orig_call_expr], [nodes.ARG_POS], ["arg1"])
+ stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
-def _add_additional_orm_attributes(
+def add_additional_orm_attributes(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Apply __init__, __table__ and other attributes to the mapped class."""
- info = util._info_for_cls(cls, api)
- if "__init__" not in info.names and cls_metadata.is_mapped:
- mapped_attr_names = {n: t for n, t in cls_metadata.mapped_attr_names}
+ info = util.info_for_cls(cls, api)
- for mapped_base in cls_metadata.mapped_mro:
- base_cls_metadata = util.DeclClassApplied.deserialize(
- mapped_base.type.metadata["_sa_decl_class_applied"], api
- )
- for n, t in base_cls_metadata.mapped_attr_names:
- mapped_attr_names.setdefault(n, t)
+ if info is None:
+ return
+
+ is_base = util.get_is_base(info)
+
+ if "__init__" not in info.names and not is_base:
+ mapped_attr_names = {attr.name: attr.type for attr in attributes}
+
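+        # collect mapped attributes from superclasses as well so that they
+        # also show up as keyword arguments in the generated __init__()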
+ for base in info.mro[1:-1]:
+            if "sqlalchemy" not in base.metadata:
+ continue
+
+ base_cls_attributes = util.get_mapped_attributes(base, api)
+ if base_cls_attributes is None:
+ continue
+
+ for attr in base_cls_attributes:
+ mapped_attr_names.setdefault(attr.name, attr.type)
arguments = []
for name, typ in mapped_attr_names.items():
kind=ARG_NAMED_OPT,
)
)
+
add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
- if "__table__" not in info.names and cls_metadata.has_table:
+ if "__table__" not in info.names and util.get_has_table(info):
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.sql.schema.Table", "__table__"
)
- if cls_metadata.is_mapped:
+ if not is_base:
_apply_placeholder_attr_to_class(
api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
)
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import List
from typing import Optional
from typing import Union
-from mypy import nodes
from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
+from mypy.nodes import LambdaExpr
from mypy.nodes import ListExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from . import util
-def _scan_declarative_assignments_and_apply_types(
+def scan_declarative_assignments_and_apply_types(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
is_mixin_scan: bool = False,
-) -> Optional[util.DeclClassApplied]:
+) -> Optional[List[util.SQLAlchemyAttribute]]:
- info = util._info_for_cls(cls, api)
+ info = util.info_for_cls(cls, api)
if info is None:
# this can occur during cached passes
return None
elif cls.fullname.startswith("builtins"):
return None
- elif "_sa_decl_class_applied" in info.metadata:
- cls_metadata = util.DeclClassApplied.deserialize(
- info.metadata["_sa_decl_class_applied"], api
- )
+ mapped_attributes: Optional[
+ List[util.SQLAlchemyAttribute]
+ ] = util.get_mapped_attributes(info, api)
+
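+    # if a prior pass already scanned this class, its attributes are stored
+    # in the TypeInfo metadata; re-apply them instead of scanning again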
+ if mapped_attributes is not None:
# ensure that a class that's mapped is always picked up by
# its mapped() decorator or declarative metaclass before
# it would be detected as an unmapped mixin class
- if not is_mixin_scan:
- assert cls_metadata.is_mapped
+ if not is_mixin_scan:
# mypy can call us more than once. it then *may* have reset the
# left hand side of everything, but not the right that we removed,
# removing our ability to re-scan. but we have the types
# here, so lets re-apply them, or if we have an UnboundType,
# we can re-scan
- apply._re_apply_declarative_assignments(cls, api, cls_metadata)
+ apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
- return cls_metadata
+ return mapped_attributes
- cls_metadata = util.DeclClassApplied(not is_mixin_scan, False, [], [])
+ mapped_attributes = []
if not cls.defs.body:
# when we get a mixin class from another file, the body is
# empty (!) but the names are in the symbol table. so use that.
for sym_name, sym in info.names.items():
- _scan_symbol_table_entry(cls, api, sym_name, sym, cls_metadata)
+ _scan_symbol_table_entry(
+ cls, api, sym_name, sym, mapped_attributes
+ )
else:
- for stmt in util._flatten_typechecking(cls.defs.body):
+ for stmt in util.flatten_typechecking(cls.defs.body):
if isinstance(stmt, AssignmentStmt):
- _scan_declarative_assignment_stmt(cls, api, stmt, cls_metadata)
+ _scan_declarative_assignment_stmt(
+ cls, api, stmt, mapped_attributes
+ )
elif isinstance(stmt, Decorator):
- _scan_declarative_decorator_stmt(cls, api, stmt, cls_metadata)
- _scan_for_mapped_bases(cls, api, cls_metadata)
+ _scan_declarative_decorator_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ _scan_for_mapped_bases(cls, api)
if not is_mixin_scan:
- apply._add_additional_orm_attributes(cls, api, cls_metadata)
+ apply.add_additional_orm_attributes(cls, api, mapped_attributes)
- info.metadata["_sa_decl_class_applied"] = cls_metadata.serialize()
+ util.set_mapped_attributes(info, mapped_attributes)
- return cls_metadata
+ return mapped_attributes
def _scan_symbol_table_entry(
api: SemanticAnalyzerPluginInterface,
name: str,
value: SymbolTableNode,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a SymbolTableNode that's in the
type.names dictionary.
return
left_hand_explicit_type = None
- type_id = names._type_id_for_named_node(value_type.type)
+ type_id = names.type_id_for_named_node(value_type.type)
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
err = False
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
- if names._has_base_type_id(sym.node, names.TYPEENGINE):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
- infer._extract_python_type_from_typeengine(
+ infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
if left_hand_explicit_type is not None:
- cls_metadata.mapped_attr_names.append((name, left_hand_explicit_type))
+ assert value.node is not None
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=value.node.line,
+ column=value.node.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
def _scan_declarative_decorator_stmt(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: Decorator,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from a @declared_attr in a declarative
class.
for dec in stmt.decorators:
if (
isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
- and names._type_id_for_named_node(dec) is names.DECLARED_ATTR
+ and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
):
break
else:
if isinstance(stmt.func.type, CallableType):
func_type = stmt.func.type.ret_type
if isinstance(func_type, UnboundType):
- type_id = names._type_id_for_unbound_type(func_type, cls, api)
+ type_id = names.type_id_for_unbound_type(func_type, cls, api)
else:
# this does not seem to occur unless the type argument is
# incorrect
if isinstance(typeengine_arg, UnboundType):
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
- if names._has_base_type_id(sym.node, names.TYPEENGINE):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
left_hand_explicit_type = UnionType(
[
- infer._extract_python_type_from_typeengine(
+ infer.extract_python_type_from_typeengine(
api, sym.node, []
),
NoneType(),
# we see everywhere else.
if isinstance(left_hand_explicit_type, UnboundType):
left_hand_explicit_type = get_proper_type(
- util._unbound_to_instance(api, left_hand_explicit_type)
+ util.unbound_to_instance(api, left_hand_explicit_type)
)
left_node.node.type = api.named_type(
# <attr> : Mapped[<typ>] =
# _sa_Mapped._empty_constructor(lambda: <function body>)
# the function body is maintained so it gets type checked internally
- column_descriptor = nodes.NameExpr("__sa_Mapped")
- column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
- mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")
-
- arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body)
- rvalue = CallExpr(
- mm,
- [arg],
- [nodes.ARG_POS],
- ["arg1"],
+ rvalue = util.expr_to_mapped_constructor(
+ LambdaExpr(stmt.func.arguments, stmt.func.body)
)
new_stmt = AssignmentStmt([left_node], rvalue)
new_stmt.type = left_node.node.type
- cls_metadata.mapped_attr_names.append(
- (left_node.name, left_hand_explicit_type)
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=left_node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
)
cls.defs.body[dec_index] = new_stmt
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
- cls_metadata: util.DeclClassApplied,
+ attributes: List[util.SQLAlchemyAttribute],
) -> None:
"""Extract mapping information from an assignment statement in a
declarative class.
if node.name == "__abstract__":
if api.parse_bool(stmt.rvalue) is True:
- cls_metadata.is_mapped = False
+ util.set_is_base(cls.info)
return
elif node.name == "__tablename__":
- cls_metadata.has_table = True
+ util.set_has_table(cls.info)
elif node.name.startswith("__"):
return
elif node.name == "_mypy_mapped_attrs":
else:
for item in stmt.rvalue.items:
if isinstance(item, (NameExpr, StrExpr)):
- apply._apply_mypy_mapped_attr(cls, api, item, cls_metadata)
+ apply.apply_mypy_mapped_attr(cls, api, item, attributes)
left_hand_mapped_type: Optional[Type] = None
left_hand_explicit_type: Optional[ProperType] = None
if (
mapped_sym is not None
and mapped_sym.node is not None
- and names._type_id_for_named_node(mapped_sym.node)
+ and names.type_id_for_named_node(mapped_sym.node)
is names.MAPPED
):
left_hand_explicit_type = get_proper_type(
node_type = get_proper_type(node.type)
if (
isinstance(node_type, Instance)
- and names._type_id_for_named_node(node_type.type) is names.MAPPED
+ and names.type_id_for_named_node(node_type.type) is names.MAPPED
):
# print(node.type)
# sqlalchemy.orm.attributes.Mapped[<python type>]
stmt.rvalue.callee, RefExpr
):
- python_type_for_type = infer._infer_type_from_right_hand_nameexpr(
+ python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
)
assert python_type_for_type is not None
- cls_metadata.mapped_attr_names.append((node.name, python_type_for_type))
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=python_type_for_type,
+ info=cls.info,
+ )
+ )
- apply._apply_type_to_mapped_statement(
+ apply.apply_type_to_mapped_statement(
api,
stmt,
lvalue,
def _scan_for_mapped_bases(
cls: ClassDef,
api: SemanticAnalyzerPluginInterface,
- cls_metadata: util.DeclClassApplied,
) -> None:
"""Given a class, iterate through its superclass hierarchy to find
all other classes that are considered as ORM-significant.
"""
- info = util._info_for_cls(cls, api)
+ info = util.info_for_cls(cls, api)
- baseclasses = list(info.bases)
-
- while baseclasses:
- base: Instance = baseclasses.pop(0)
+ if info is None:
+ return
- if base.type.fullname.startswith("builtins"):
+ for base_info in info.mro[1:-1]:
+ if base_info.fullname.startswith("builtins"):
continue
# scan each base for mapped attributes. if they are not already
# scanned (but have all their type info), that means they are unmapped
# mixins
- base_decl_class_applied = (
- _scan_declarative_assignments_and_apply_types(
- base.type.defn, api, is_mixin_scan=True
- )
+ scan_declarative_assignments_and_apply_types(
+ base_info.defn, api, is_mixin_scan=True
)
-
- if base_decl_class_applied is not None:
- cls_metadata.mapped_mro.append(base)
- baseclasses.extend(base.type.bases)
from . import util
-def _infer_type_from_right_hand_nameexpr(
+def infer_type_from_right_hand_nameexpr(
api: SemanticAnalyzerPluginInterface,
stmt: AssignmentStmt,
node: Var,
left_hand_explicit_type: Optional[ProperType],
- infer_from_right_side: NameExpr,
+ infer_from_right_side: RefExpr,
) -> Optional[ProperType]:
- type_id = names._type_id_for_callee(infer_from_right_side)
+ type_id = names.type_id_for_callee(infer_from_right_side)
if type_id is None:
return None
api, stmt, node, left_hand_explicit_type
)
elif type_id is names.SYNONYM_PROPERTY:
- python_type_for_type = _infer_type_from_left_hand_type_only(
+ python_type_for_type = infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif type_id is names.COMPOSITE_PROPERTY:
# string expression
# isinstance(target_cls_arg, StrExpr)
- uselist_arg = util._get_callexpr_kwarg(stmt.rvalue, "uselist")
- collection_cls_arg: Optional[Expression] = util._get_callexpr_kwarg(
+ uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
+ collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
stmt.rvalue, "collection_class"
)
type_is_a_collection = False
# this can be used to determine Optional for a many-to-one
# in the same way nullable=False could be used, if we start supporting
# that.
- # innerjoin_arg = _get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+ # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
if (
uselist_arg is not None
util.fail(api, msg.format(node.name), node)
if python_type_for_type is None:
- return _infer_type_from_left_hand_type_only(
+ return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
python_type_for_type = None
if python_type_for_type is None:
- return _infer_type_from_left_hand_type_only(
+ return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
elif left_hand_explicit_type is not None:
first_prop_arg = stmt.rvalue.args[0]
if isinstance(first_prop_arg, CallExpr):
- type_id = names._type_id_for_callee(first_prop_arg.callee)
+ type_id = names.type_id_for_callee(first_prop_arg.callee)
# look for column_property() / deferred() etc with Column as first
# argument
right_hand_expression=first_prop_arg,
)
- return _infer_type_from_left_hand_type_only(
+ return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
if callee is None:
return None
- if isinstance(callee.node, TypeInfo) and names._mro_has_id(
+ if isinstance(callee.node, TypeInfo) and names.mro_has_id(
callee.node.mro, names.TYPEENGINE
):
- python_type_for_type = _extract_python_type_from_typeengine(
+ python_type_for_type = extract_python_type_from_typeengine(
api, callee.node, type_args
)
else:
# it's not TypeEngine, it's typically implicitly typed
# like ForeignKey. we can't infer from the right side.
- return _infer_type_from_left_hand_type_only(
+ return infer_type_from_left_hand_type_only(
api, node, left_hand_explicit_type
)
)
-def _infer_type_from_left_hand_type_only(
+def infer_type_from_left_hand_type_only(
api: SemanticAnalyzerPluginInterface,
node: Var,
left_hand_explicit_type: Optional[ProperType],
return left_hand_explicit_type
-def _extract_python_type_from_typeengine(
+def extract_python_type_from_typeengine(
api: SemanticAnalyzerPluginInterface,
node: TypeInfo,
type_args: Sequence[Expression],
}
-def _has_base_type_id(info: TypeInfo, type_id: int) -> bool:
+def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
for mr in info.mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
return mr.fullname in fullnames
-def _mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
+def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
for mr in mro:
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
if check_type_id == type_id:
return mr.fullname in fullnames
-def _type_id_for_unbound_type(
+def type_id_for_unbound_type(
type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
) -> Optional[int]:
- type_id = None
-
sym = api.lookup_qualified(type_.name, type_)
if sym is not None:
if isinstance(sym.node, TypeAlias):
target_type = get_proper_type(sym.node.target)
if isinstance(target_type, Instance):
- type_id = _type_id_for_named_node(target_type.type)
+ return type_id_for_named_node(target_type.type)
elif isinstance(sym.node, TypeInfo):
- type_id = _type_id_for_named_node(sym.node)
+ return type_id_for_named_node(sym.node)
- return type_id
+ return None
-def _type_id_for_callee(callee: Expression) -> Optional[int]:
+def type_id_for_callee(callee: Expression) -> Optional[int]:
if isinstance(callee, (MemberExpr, NameExpr)):
if isinstance(callee.node, FuncDef):
- return _type_id_for_funcdef(callee.node)
+ if callee.node.type and isinstance(callee.node.type, CallableType):
+ ret_type = get_proper_type(callee.node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return type_id_for_fullname(ret_type.type.fullname)
+
+ return None
elif isinstance(callee.node, TypeAlias):
target_type = get_proper_type(callee.node.target)
if isinstance(target_type, Instance):
- type_id = _type_id_for_fullname(target_type.type.fullname)
+ return type_id_for_fullname(target_type.type.fullname)
elif isinstance(callee.node, TypeInfo):
- type_id = _type_id_for_named_node(callee)
- else:
- type_id = None
- return type_id
-
-
-def _type_id_for_funcdef(node: FuncDef) -> Optional[int]:
- if node.type and isinstance(node.type, CallableType):
- ret_type = get_proper_type(node.type.ret_type)
-
- if isinstance(ret_type, Instance):
- return _type_id_for_fullname(ret_type.type.fullname)
-
+ return type_id_for_named_node(callee)
return None
-def _type_id_for_named_node(
+def type_id_for_named_node(
node: Union[NameExpr, MemberExpr, SymbolNode]
) -> Optional[int]:
type_id, fullnames = _lookup.get(node.name, (None, None))
return None
-def _type_id_for_fullname(fullname: str) -> Optional[int]:
+def type_id_for_fullname(fullname: str) -> Optional[int]:
tokens = fullname.split(".")
immediate = tokens[-1]
def get_dynamic_class_hook(
self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
- if names._type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
return _dynamic_class_hook
return None
- def get_base_class_hook(
+ def get_customize_class_mro_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
-
- # kind of a strange relationship between get_metaclass_hook()
- # and get_base_class_hook(). the former doesn't fire off for
- # subclasses. but then you can just check it here from the "base"
- # and get the same effect.
- sym = self.lookup_fully_qualified(fullname)
-
- if (
- sym
- and isinstance(sym.node, TypeInfo)
- and sym.node.metaclass_type
- and names._type_id_for_named_node(sym.node.metaclass_type.type)
- is names.DECLARATIVE_META
- ):
- return _base_cls_hook
- return None
+ return _fill_in_decorators
def get_class_decorator_hook(
self, fullname: str
sym = self.lookup_fully_qualified(fullname)
if sym is not None and sym.node is not None:
- type_id = names._type_id_for_named_node(sym.node)
+ type_id = names.type_id_for_named_node(sym.node)
if type_id is names.MAPPED_DECORATOR:
return _cls_decorator_hook
elif type_id in (
return None
- def get_customize_class_mro_hook(
+ def get_metaclass_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
- return _fill_in_decorators
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
+ # Set any classes that explicitly have metaclass=DeclarativeMeta
+ # as declarative so the check in `get_base_class_hook()` works
+ return _metaclass_cls_hook
+
+ return None
+
+ def get_base_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ sym = self.lookup_fully_qualified(fullname)
+
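+        # a base class participates in declarative mapping if it, or any
+        # class in its MRO, has been marked as a declarative base
+        # (declarative_base(), one of the declarative decorators, or an
+        # explicit DeclarativeMeta metaclass)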
+ if (
+ sym
+ and isinstance(sym.node, TypeInfo)
+ and util.has_declarative_base(sym.node)
+ ):
+ return _base_cls_hook
+
+ return None
def get_attribute_hook(
self, fullname: str
"sqlalchemy.orm.attributes.QueryableAttribute."
):
return _queryable_getattr_hook
+
return None
def get_additional_deps(
return SQLAlchemyPlugin
-def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
- # how do I....tell it it has no attribute of a certain name?
- # can't find any Type that seems to match that
- return ctx.default_attr_type
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+ """Generate a declarative Base class when the declarative_base() function
+ is encountered."""
+
+ _add_globals(ctx)
+
+ cls = ClassDef(ctx.name, Block([]))
+ cls.fullname = ctx.api.qualified_name(ctx.name)
+
+ info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+ cls.info = info
+ _set_declarative_metaclass(ctx.api, cls)
+
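+    # declarative_base(cls=X): scan the given class as a declarative mixin
+    # and use it as the base of the generated class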
+ cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+ if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
+ util.set_is_base(cls_arg.node)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls_arg.node.defn, ctx.api, is_mixin_scan=True
+ )
+ info.bases = [Instance(cls_arg.node, [])]
+ else:
+ obj = ctx.api.named_type("__builtins__.object")
+
+ info.bases = [obj]
+
+ try:
+ calculate_mro(info)
+ except MroError:
+ util.fail(
+ ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+ )
+ obj = ctx.api.named_type("__builtins__.object")
+ info.bases = [obj]
+ info.fallback_to_any = True
+
+ ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+ util.set_is_base(info)
def _fill_in_decorators(ctx: ClassDefContext) -> None:
)
-def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
- """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
- for all class defs
-
- """
-
- util.add_global(
- ctx,
- "sqlalchemy.orm.decl_api",
- "DeclarativeMeta",
- "__sa_DeclarativeMeta",
- )
-
- util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
-
-
-def _cls_metadata_hook(ctx: ClassDefContext) -> None:
- _add_globals(ctx)
- decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
-
-
-def _base_cls_hook(ctx: ClassDefContext) -> None:
- _add_globals(ctx)
- decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
-
-
-def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
- _add_globals(ctx)
- decl_class._scan_declarative_assignments_and_apply_types(
- ctx.cls, ctx.api, is_mixin_scan=True
- )
-
-
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
assert isinstance(ctx.reason, nodes.MemberExpr)
assert (
isinstance(node_type, Instance)
- and names._type_id_for_named_node(node_type.type) is names.REGISTRY
+ and names.type_id_for_named_node(node_type.type) is names.REGISTRY
)
- decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
cls = ctx.cls
- _make_declarative_meta(ctx.api, cls)
+ _set_declarative_metaclass(ctx.api, cls)
- decl_class._scan_declarative_assignments_and_apply_types(
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
cls, ctx.api, is_mixin_scan=True
)
-def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
- """Generate a declarative Base class when the declarative_base() function
- is encountered."""
-
+def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
_add_globals(ctx)
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ ctx.cls, ctx.api, is_mixin_scan=True
+ )
- cls = ClassDef(ctx.name, Block([]))
- cls.fullname = ctx.api.qualified_name(ctx.name)
-
- info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
- cls.info = info
- _make_declarative_meta(ctx.api, cls)
-
- cls_arg = util._get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
- if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
- decl_class._scan_declarative_assignments_and_apply_types(
- cls_arg.node.defn, ctx.api, is_mixin_scan=True
- )
- info.bases = [Instance(cls_arg.node, [])]
- else:
- obj = ctx.api.named_type("__builtins__.object")
-
- info.bases = [obj]
- try:
- calculate_mro(info)
- except MroError:
- util.fail(
- ctx.api, "Not able to calculate MRO for declarative base", ctx.call
- )
- obj = ctx.api.named_type("__builtins__.object")
- info.bases = [obj]
- info.fallback_to_any = True
+def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
+ util.set_is_base(ctx.cls.info)
- ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
-def _make_declarative_meta(
- api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
-) -> None:
- declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta")
- declarative_meta_name.kind = GDEF
- declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta"
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+ # how do I....tell it it has no attribute of a certain name?
+ # can't find any Type that seems to match that
+ return ctx.default_attr_type
- # installed by _add_globals
- sym = api.lookup_qualified("__sa_DeclarativeMeta", target_cls)
- assert sym is not None and isinstance(sym.node, nodes.TypeInfo)
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
+    """Add the __sa_Mapped symbol to the global space
+    for all class defs
- declarative_meta_typeinfo = sym.node
- declarative_meta_name.node = declarative_meta_typeinfo
+ """
- target_cls.metaclass = declarative_meta_name
+ util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
- declarative_meta_instance = Instance(declarative_meta_typeinfo, [])
+def _set_declarative_metaclass(
+ api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
+) -> None:
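+    # look up DeclarativeMeta directly and install it as the declared
+    # metaclass of the target class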
info = target_cls.info
- info.declared_metaclass = info.metaclass_type = declarative_meta_instance
+ sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.orm.decl_api.DeclarativeMeta"
+ )
+ assert sym is not None and isinstance(sym.node, TypeInfo)
+ info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
from typing import Any
-from typing import cast
from typing import Iterable
from typing import Iterator
from typing import List
from typing import TypeVar
from typing import Union
+from mypy.nodes import ARG_POS
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import CLASSDEF_NO_INFO
from mypy.nodes import Context
+from mypy.nodes import Expression
from mypy.nodes import IfStmt
from mypy.nodes import JsonDict
+from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import Statement
from mypy.nodes import SymbolTableNode
from mypy.plugin import DynamicClassDefContext
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.typeops import map_type_from_supertype
from mypy.types import Instance
from mypy.types import NoneType
-from mypy.types import ProperType
from mypy.types import Type
+from mypy.types import TypeVarType
from mypy.types import UnboundType
from mypy.types import UnionType
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
-class DeclClassApplied:
+class SQLAlchemyAttribute:
def __init__(
self,
- is_mapped: bool,
- has_table: bool,
- mapped_attr_names: Iterable[Tuple[str, ProperType]],
- mapped_mro: Iterable[Instance],
- ):
- self.is_mapped = is_mapped
- self.has_table = has_table
- self.mapped_attr_names = list(mapped_attr_names)
- self.mapped_mro = list(mapped_mro)
+ name: str,
+ line: int,
+ column: int,
+ typ: Optional[Type],
+ info: TypeInfo,
+ ) -> None:
+ self.name = name
+ self.line = line
+ self.column = column
+ self.type = typ
+ self.info = info
def serialize(self) -> JsonDict:
+ assert self.type
return {
- "is_mapped": self.is_mapped,
- "has_table": self.has_table,
- "mapped_attr_names": [
- (name, type_.serialize())
- for name, type_ in self.mapped_attr_names
- ],
- "mapped_mro": [type_.serialize() for type_ in self.mapped_mro],
+ "name": self.name,
+ "line": self.line,
+ "column": self.column,
+ "type": self.type.serialize(),
}
+ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
+        """Expands type vars in the context of a subtype when an
+        attribute is inherited from a generic super type."""
+ if not isinstance(self.type, TypeVarType):
+ return
+
+ self.type = map_type_from_supertype(self.type, sub_type, self.info)
+
@classmethod
def deserialize(
- cls, data: JsonDict, api: SemanticAnalyzerPluginInterface
- ) -> "DeclClassApplied":
-
- return DeclClassApplied(
- is_mapped=data["is_mapped"],
- has_table=data["has_table"],
- mapped_attr_names=cast(
- List[Tuple[str, ProperType]],
- [
- (name, deserialize_and_fixup_type(type_, api))
- for name, type_ in data["mapped_attr_names"]
- ],
- ),
- mapped_mro=cast(
- List[Instance],
- [
- deserialize_and_fixup_type(type_, api)
- for type_ in data["mapped_mro"]
- ],
- ),
- )
+ cls,
+ info: TypeInfo,
+ data: JsonDict,
+ api: SemanticAnalyzerPluginInterface,
+ ) -> "SQLAlchemyAttribute":
+ data = data.copy()
+ typ = deserialize_and_fixup_type(data.pop("type"), api)
+ return cls(typ=typ, info=info, **data)
+
+
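+# plugin state is stored in TypeInfo.metadata under the "sqlalchemy" key so
+# that it is serialized along with mypy's cache and survives incremental runs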
+def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
+ info.metadata.setdefault("sqlalchemy", {})[key] = data
+
+
+def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ return info.metadata.get("sqlalchemy", {}).get(key, None)
+
+
+def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ if info.mro:
+ for base in info.mro:
+ metadata = _get_info_metadata(base, key)
+ if metadata is not None:
+ return metadata
+ return None
+
+
+def set_is_base(info: TypeInfo) -> None:
+ _set_info_metadata(info, "is_base", True)
+
+
+def get_is_base(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "is_base")
+ return is_base is True
+
+
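+# unlike get_is_base(), this checks the full MRO, so subclasses of a
+# declarative base are detected as well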
+def has_declarative_base(info: TypeInfo) -> bool:
+ is_base = _get_info_mro_metadata(info, "is_base")
+ return is_base is True
+
+
+def set_has_table(info: TypeInfo) -> None:
+ _set_info_metadata(info, "has_table", True)
+
+
+def get_has_table(info: TypeInfo) -> bool:
+    has_table = _get_info_metadata(info, "has_table")
+    return has_table is True
+
+
+def get_mapped_attributes(
+ info: TypeInfo, api: SemanticAnalyzerPluginInterface
+) -> Optional[List[SQLAlchemyAttribute]]:
+ mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
+ info, "mapped_attributes"
+ )
+ if mapped_attributes is None:
+ return None
+
+ attributes: List[SQLAlchemyAttribute] = []
+
+ for data in mapped_attributes:
+ attr = SQLAlchemyAttribute.deserialize(info, data, api)
+ attr.expand_typevar_from_subtype(info)
+ attributes.append(attr)
+
+ return attributes
+
+
+def set_mapped_attributes(
+ info: TypeInfo, attributes: List[SQLAlchemyAttribute]
+) -> None:
+ _set_info_metadata(
+ info,
+ "mapped_attributes",
+ [attribute.serialize() for attribute in attributes],
+ )
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
@overload
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
...
@overload
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
...
-def _get_callexpr_kwarg(
+def get_callexpr_kwarg(
callexpr: CallExpr,
name: str,
*,
return None
-def _flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
for stmt in stmts:
if (
isinstance(stmt, IfStmt)
yield stmt
-def _unbound_to_instance(
+def unbound_to_instance(
api: SemanticAnalyzerPluginInterface, typ: Type
) -> Type:
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
if typ.name == "Optional":
# convert from "Optional?" to the more familiar
# UnionType[..., NoneType()]
- return _unbound_to_instance(
+ return unbound_to_instance(
api,
UnionType(
- [_unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ [NoneType()]
),
)
return Instance(
bound_type,
[
- _unbound_to_instance(api, arg)
+ unbound_to_instance(api, arg)
if isinstance(arg, UnboundType)
else arg
for arg in typ.args
return typ
-def _info_for_cls(
+def info_for_cls(
cls: ClassDef, api: SemanticAnalyzerPluginInterface
-) -> TypeInfo:
+) -> Optional[TypeInfo]:
if cls.info is CLASSDEF_NO_INFO:
sym = api.lookup_qualified(cls.name, cls)
if sym is None:
return sym.node
return cls.info
+
+
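+# wrap an expression in __sa_Mapped._empty_constructor(<expr>); the original
+# expression is kept as the argument so it still gets type checked internally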
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )
asyncio =
greenlet!=0.4.17;python_version>="3"
mypy =
- mypy >= 0.800;python_version>="3"
+ mypy >= 0.910;python_version>="3"
sqlalchemy2-stubs
mssql = pyodbc
mssql_pymssql = pymssql
--- /dev/null
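+# mixin and __abstract__ classes should not receive the __mapper__
+# attribute; classes mapped against the declarative base should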
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import registry
+
+
+reg: registry = registry()
+
+Base = declarative_base()
+
+
+class SomeAbstract(Base):
+ __abstract__ = True
+
+
+class HasUpdatedAt:
+ updated_at = Column(Integer)
+
+
+@reg.mapped
+class Foo(SomeAbstract):
+ __tablename__ = "foo"
+ id: int = Column(Integer(), primary_key=True)
+ name: str = Column(String)
+
+
+class Bar(HasUpdatedAt, Base):
+ __tablename__ = "bar"
+ id = Column(Integer(), primary_key=True)
+ num = Column(Integer)
+
+
+Bar.__mapper__
+
+# EXPECTED_MYPY: "Type[HasUpdatedAt]" has no attribute "__mapper__"
+HasUpdatedAt.__mapper__
+
+
+# EXPECTED_MYPY: "Type[SomeAbstract]" has no attribute "__mapper__"
+SomeAbstract.__mapper__