From 803a2373f58e794499e1e0b476db4c23e8dd7f87 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 28 Mar 2021 11:09:40 -0400 Subject: [PATCH] Add DeclarativeMeta to globals Fixed issue in mypy plugin where newly added support for :func:`_orm.as_declarative` needed to more fully add the ``DeclarativeMeta`` class to the mypy interpreter's state so that it does not result in a name not found error; additionally improves how global names are setup for the plugin including the ``Mapped`` name. Introduces directory oriented testing as well, where a full set of files will be copied, mypy runs, then zero or more patches are applied and mypy is run again, to fully test incremental behaviors. Fixes: sqlalchemy/sqlalchemy2-stubs/#14 Change-Id: Ide785c07e19ba0694e8cf6f91560094ecb182016 --- MANIFEST.in | 2 +- .../changelog/unreleased_14/stubs_14.rst | 10 +++ lib/sqlalchemy/ext/mypy/decl_class.py | 16 ++--- lib/sqlalchemy/ext/mypy/plugin.py | 66 +++++++++++++------ lib/sqlalchemy/ext/mypy/util.py | 16 +++++ lib/sqlalchemy/testing/requirements.py | 12 ++++ test/ext/mypy/files/as_declarative.py | 18 ++++- .../ext/mypy/incremental/stubs_14/__init__.py | 24 +++++++ test/ext/mypy/incremental/stubs_14/address.py | 13 ++++ .../incremental/stubs_14/patch1.testpatch | 13 ++++ test/ext/mypy/incremental/stubs_14/user.py | 39 +++++++++++ test/ext/mypy/test_mypy_plugin_py3k.py | 61 ++++++++++++++++- tox.ini | 1 + 13 files changed, 256 insertions(+), 35 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/stubs_14.rst create mode 100644 test/ext/mypy/incremental/stubs_14/__init__.py create mode 100644 test/ext/mypy/incremental/stubs_14/address.py create mode 100644 test/ext/mypy/incremental/stubs_14/patch1.testpatch create mode 100644 test/ext/mypy/incremental/stubs_14/user.py diff --git a/MANIFEST.in b/MANIFEST.in index 372cc16525..6d04f593c8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,7 +3,7 @@ recursive-include doc *.html *.css *.txt *.js *.png *.py Makefile *.rst *.sty recursive-include examples *.py *.xml -recursive-include test *.py *.dat +recursive-include test *.py *.dat *.testpatch # include the c extensions, which otherwise # don't come in if --with-cextensions isn't specified. diff --git a/doc/build/changelog/unreleased_14/stubs_14.rst b/doc/build/changelog/unreleased_14/stubs_14.rst new file mode 100644 index 0000000000..1e4effde41 --- /dev/null +++ b/doc/build/changelog/unreleased_14/stubs_14.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, mypy + :tickets: sqlalchemy/sqlalchemy2-stubs/#14 + + Fixed issue in mypy plugin where newly added support for + :func:`_orm.as_declarative` needed to more fully add the + ``DeclarativeMeta`` class to the mypy interpreter's state so that it does + not result in a name not found error; additionally improves how global + names are setup for the plugin including the ``Mapped`` name. + diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index 9fb1fa807a..46f3cc30e3 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -229,7 +229,7 @@ def _scan_declarative_decorator_stmt( left_hand_explicit_type = AnyType(TypeOfAny.special_form) - descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"] + descriptor = api.lookup("__sa_Mapped", cls) left_node = NameExpr(stmt.var.name) left_node.node = stmt.var @@ -254,8 +254,7 @@ def _scan_declarative_decorator_stmt( # : Mapped[] = # _sa_Mapped._empty_constructor(lambda: ) # the function body is maintained so it gets type checked internally - api.add_symbol_table_node("_sa_Mapped", descriptor) - column_descriptor = nodes.NameExpr("_sa_Mapped") + column_descriptor = nodes.NameExpr("__sa_Mapped") column_descriptor.fullname = "sqlalchemy.orm.Mapped" mm = nodes.MemberExpr(column_descriptor, "_empty_constructor") @@ -753,7 +752,7 @@ def _infer_type_from_left_and_inferred_right( """ if not is_subtype(left_hand_explicit_type, python_type_for_type): - descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"] + descriptor = api.lookup("__sa_Mapped", node) effective_type = Instance(descriptor.node, [python_type_for_type]) @@ -794,8 +793,7 @@ def _infer_type_from_left_hand_type_only( ) util.fail(api, msg.format(node.name), node) - descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"] - + descriptor = api.lookup("__sa_Mapped", node) return Instance(descriptor.node, [AnyType(TypeOfAny.special_form)]) else: @@ -816,8 +814,7 @@ def _re_apply_declarative_assignments( name: typ for name, typ in cls_metadata.mapped_attr_names } - descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"] - + descriptor = api.lookup("__sa_Mapped", cls) for stmt in cls.defs.body: # for a re-apply, all of our statements are AssignmentStmt; # @declared_attr calls will have been converted and this @@ -859,8 +856,7 @@ def _apply_type_to_mapped_statement( attrname : Mapped[Optional[int]] = """ - descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"] - + descriptor = api.lookup("__sa_Mapped", stmt) left_node = lvalue.node inst = Instance(descriptor.node, [python_type_for_type]) diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py index c8fbcd6a21..9ca1cb2daf 100644 --- a/lib/sqlalchemy/ext/mypy/plugin.py +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -126,6 +126,7 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None: # and "registry.as_declarative_base()" methods. # this seems like a bug in mypy that these decorators are otherwise # skipped. + if ( isinstance(decorator, nodes.CallExpr) and isinstance(decorator.callee, nodes.MemberExpr) @@ -165,15 +166,34 @@ def _fill_in_decorators(ctx: ClassDefContext) -> None: ) +def _add_globals(ctx: ClassDefContext): + """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space + for all class defs + + """ + + util.add_global( + ctx, + "sqlalchemy.orm.decl_api", + "DeclarativeMeta", + "__sa_DeclarativeMeta", + ) + + util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped") + + def _cls_metadata_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) def _base_cls_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) def _cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) assert isinstance(ctx.reason, nodes.MemberExpr) expr = ctx.reason.expr assert names._type_id_for_named_node(expr.node.type.type) is names.REGISTRY @@ -181,28 +201,8 @@ def _cls_decorator_hook(ctx: ClassDefContext) -> None: decl_class._scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) -def _make_declarative_meta( - api: SemanticAnalyzerPluginInterface, target_cls: ClassDef -): - declarative_meta_sym: SymbolTableNode = api.modules[ - "sqlalchemy.orm.decl_api" - ].names["DeclarativeMeta"] - declarative_meta_typeinfo: TypeInfo = declarative_meta_sym.node - - declarative_meta_name: NameExpr = NameExpr("DeclarativeMeta") - declarative_meta_name.kind = GDEF - declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta" - declarative_meta_name.node = declarative_meta_typeinfo - - target_cls.metaclass = declarative_meta_name - - declarative_meta_instance = Instance(declarative_meta_typeinfo, []) - - info = target_cls.info - info.declared_metaclass = info.metaclass_type = declarative_meta_instance - - def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) cls = ctx.cls @@ -217,6 +217,8 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: """Generate a declarative Base class when the declarative_base() function is encountered.""" + _add_globals(ctx) + cls = ClassDef(ctx.name, Block([])) cls.fullname = ctx.api.qualified_name(ctx.name) @@ -246,3 +248,25 @@ def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: info.fallback_to_any = True ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + + +def _make_declarative_meta( + api: SemanticAnalyzerPluginInterface, target_cls: ClassDef +): + + declarative_meta_name: NameExpr = NameExpr("__sa_DeclarativeMeta") + declarative_meta_name.kind = GDEF + declarative_meta_name.fullname = "sqlalchemy.orm.decl_api.DeclarativeMeta" + + # installed by _add_globals + sym = api.lookup("__sa_DeclarativeMeta", target_cls) + + declarative_meta_typeinfo = sym.node + declarative_meta_name.node = declarative_meta_typeinfo + + target_cls.metaclass = declarative_meta_name + + declarative_meta_instance = Instance(declarative_meta_typeinfo, []) + + info = target_cls.info + info.declared_metaclass = info.metaclass_type = declarative_meta_instance diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index e7178a885d..7079f3cd78 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -18,6 +18,22 @@ def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context): return api.fail(msg, ctx) +def add_global( + ctx: SemanticAnalyzerPluginInterface, + module: str, + symbol_name: str, + asname: str, +): + module_globals = ctx.api.modules[ctx.api.cur_mod_id].names + + if asname not in module_globals: + lookup_sym: SymbolTableNode = ctx.api.modules[module].names[ + symbol_name + ] + + module_globals[asname] = lookup_sym + + def _get_callexpr_kwarg(callexpr: CallExpr, name: str) -> Optional[NameExpr]: try: arg_idx = callexpr.arg_names.index(name) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 46844803b4..b6381dd577 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1235,6 +1235,18 @@ class SuiteRequirements(Requirements): lambda: util.cpython, "cPython interpreter needed" ) + @property + def patch_library(self): + def check_lib(): + try: + __import__("patch") + except ImportError: + return False + else: + return True + + return exclusions.only_if(check_lib, "patch library needed") + @property def non_broken_pickle(self): from sqlalchemy.util import pickle diff --git a/test/ext/mypy/files/as_declarative.py b/test/ext/mypy/files/as_declarative.py index 7a3fdc068d..ab5245b20c 100644 --- a/test/ext/mypy/files/as_declarative.py +++ b/test/ext/mypy/files/as_declarative.py @@ -1,7 +1,13 @@ +from typing import List +from typing import Optional + from sqlalchemy import Column from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.ext.declarative import as_declarative +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import relationship +from sqlalchemy.sql.schema import ForeignKey @as_declarative @@ -12,7 +18,17 @@ class Base(object): class Foo(Base): __tablename__ = "foo" id: int = Column(Integer(), primary_key=True) - name: str = Column(String) + name: Mapped[str] = Column(String) + + bar: List["Bar"] = relationship("Bar") + + +class Bar(Base): + __tablename__ = "bar" + id: int = Column(Integer(), primary_key=True) + foo_id: int = Column(ForeignKey("foo.id")) + + foo: Optional[Foo] = relationship(Foo) f1 = Foo() diff --git a/test/ext/mypy/incremental/stubs_14/__init__.py b/test/ext/mypy/incremental/stubs_14/__init__.py new file mode 100644 index 0000000000..c40dd273ab --- /dev/null +++ b/test/ext/mypy/incremental/stubs_14/__init__.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy.orm import as_declarative +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import Mapped +from .address import Address +from .user import User + +if TYPE_CHECKING: + from sqlalchemy.orm.decl_api import DeclarativeMeta + + +@as_declarative() +class Base(object): + @declared_attr + def __tablename__(self) -> Mapped[str]: + return self.__name__.lower() + + id = Column(Integer, primary_key=True) + + +__all__ = ["User", "Address"] diff --git a/test/ext/mypy/incremental/stubs_14/address.py b/test/ext/mypy/incremental/stubs_14/address.py new file mode 100644 index 0000000000..dd16273164 --- /dev/null +++ b/test/ext/mypy/incremental/stubs_14/address.py @@ -0,0 +1,13 @@ +from typing import TYPE_CHECKING + +from . import Base +from .user import HasUser + +if TYPE_CHECKING: + from .user import User # noqa + from sqlalchemy import Integer, Column # noqa + from sqlalchemy.orm import RelationshipProperty # noqa + + +class Address(Base, HasUser): + pass diff --git a/test/ext/mypy/incremental/stubs_14/patch1.testpatch b/test/ext/mypy/incremental/stubs_14/patch1.testpatch new file mode 100644 index 0000000000..528236a00e --- /dev/null +++ b/test/ext/mypy/incremental/stubs_14/patch1.testpatch @@ -0,0 +1,13 @@ +diff --git a/test/ext/mypy/incremental/stubs_14/user.py b/test/ext/mypy/incremental/stubs_14/user.py +index 2c60403e4..c7e8f8874 100644 +--- a/user.py ++++ b/user.py +@@ -18,6 +18,8 @@ if TYPE_CHECKING: + class User(Base): + name = Column(String) + ++ othername = Column(String) ++ + addresses: Mapped[List["Address"]] = relationship( + "Address", back_populates="user" + ) diff --git a/test/ext/mypy/incremental/stubs_14/user.py b/test/ext/mypy/incremental/stubs_14/user.py new file mode 100644 index 0000000000..c7e8f88747 --- /dev/null +++ b/test/ext/mypy/incremental/stubs_14/user.py @@ -0,0 +1,39 @@ +from typing import List +from typing import TYPE_CHECKING + +from sqlalchemy import Column +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import relationship +from sqlalchemy.orm.decl_api import declared_attr +from sqlalchemy.orm.relationships import RelationshipProperty +from . import Base + +if TYPE_CHECKING: + from .address import Address + + +class User(Base): + name = Column(String) + + othername = Column(String) + + addresses: Mapped[List["Address"]] = relationship( + "Address", back_populates="user" + ) + + +class HasUser: + @declared_attr + def user_id(self) -> "Column[Integer]": + return Column( + Integer, + ForeignKey(User.id, ondelete="CASCADE", onupdate="CASCADE"), + nullable=False, + ) + + @declared_attr + def user(self) -> RelationshipProperty[User]: + return relationship(User) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index bf82aaa867..c8d042db0b 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -1,5 +1,6 @@ import os import re +import shutil import tempfile from sqlalchemy import testing @@ -36,8 +37,15 @@ class MypyPluginTest(fixtures.TestBase): def mypy_runner(self, cachedir): from mypy import api - def run(filename, use_plugin=True): - path = os.path.join(os.path.dirname(__file__), "files", filename) + def run( + filename, use_plugin=True, incremental=False, working_dir=None + ): + if working_dir: + path = os.path.join(working_dir, filename) + else: + path = os.path.join( + os.path.dirname(__file__), "files", filename + ) args = [ "--strict", @@ -59,6 +67,55 @@ class MypyPluginTest(fixtures.TestBase): return run + def _incremental_dirs(): + path = os.path.join(os.path.dirname(__file__), "incremental") + return [ + d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)) + ] + + @testing.combinations( + *[(dirname,) for dirname in _incremental_dirs()], argnames="dirname" + ) + @testing.requires.patch_library + def test_incremental(self, mypy_runner, cachedir, dirname): + import patch + + path = os.path.join(os.path.dirname(__file__), "incremental", dirname) + dest = os.path.join(cachedir, "mymodel") + os.mkdir(dest) + + patches = set() + + print("incremental test: %s" % dirname) + + for fname in os.listdir(path): + if fname.endswith(".py"): + shutil.copy( + os.path.join(path, fname), os.path.join(dest, fname) + ) + print("copying to: %s" % os.path.join(dest, fname)) + elif fname.endswith(".testpatch"): + patches.add(fname) + + for patchfile in [None] + sorted(patches): + if patchfile is not None: + print("Applying patchfile %s" % patchfile) + patch_obj = patch.fromfile(os.path.join(path, patchfile)) + patch_obj.apply(1, dest) + print("running mypy against %s/mymodel" % cachedir) + result = mypy_runner( + "mymodel", + use_plugin=True, + incremental=True, + working_dir=cachedir, + ) + eq_( + result[2], + 0, + msg="Failure after applying patch %s: %s" + % (patchfile, result[0]), + ) + def _file_combinations(): path = os.path.join(os.path.dirname(__file__), "files") return [f for f in os.listdir(path) if f.endswith(".py")] diff --git a/tox.ini b/tox.ini index 7831405024..6eea65c8f6 100644 --- a/tox.ini +++ b/tox.ini @@ -130,6 +130,7 @@ deps= mock; python_version < '3.3' importlib_metadata; python_version < '3.8' mypy + patch==1.* git+https://github.com/sqlalchemy/sqlalchemy2-stubs commands = pytest test/ext/mypy/test_mypy_plugin_py3k.py {posargs} -- 2.47.2