]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
try to support mypy 0.990
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Nov 2022 14:13:44 +0000 (09:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Nov 2022 20:02:29 +0000 (15:02 -0500)
mypy introduces a crash we need to work around, also
some new rules.  It also has either a behavioral change
regarding how output is rendered in relationship to
files being within sys.path or not, so work around
that for test_mypy_plugin_py3k.py

References: https://github.com/python/mypy/issues/14027
Change-Id: I689c7fe27dc52abee932de9e0fb23b2a2eba76fa

12 files changed:
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
test/ext/mypy/test_mypy_plugin_py3k.py

index 9dde9a445b29653fdd5946662ecca4c1dd803fb9..b2f6b29b7839d801bccce97c781d212f2be221f2 100644 (file)
@@ -2947,7 +2947,7 @@ class Engine(
             self.update_execution_options(**execution_options)
 
     def _lru_size_alert(self, cache: util.LRUCache[Any, Any]) -> None:
-        if self._should_log_info:
+        if self._should_log_info():
             self.logger.info(
                 "Compiled cache size pruning from %d items to %d.  "
                 "Increase cache size to reduce the frequency of pruning.",
index 7395b2fa48ff571d23b9f77ee9655a0bba8d62fb..e10fab831d9dc6c2e4d65cda93e01beb80edc27e 100644 (file)
@@ -1272,8 +1272,11 @@ class Dialect(EventTarget):
 
         This is an internal dialect method. Applications should use
         :meth:`.Inspector.get_columns`.
+
         """
 
+        raise NotImplementedError()
+
     def get_multi_columns(
         self,
         connection: Connection,
index a17c37dacebe8ef4320bfd584585287e343dd081..bfec0913766e31c1c053e4daf5862ef401c5341f 100644 (file)
@@ -148,30 +148,32 @@ class _GetterProtocol(Protocol[_T_co]):
         ...
 
 
-class _SetterProtocol(Protocol[_T_co]):
+# mypy 0.990 we are no longer allowed to make this Protocol[_T_con]
+class _SetterProtocol(Protocol):
     ...
 
 
-class _PlainSetterProtocol(_SetterProtocol[_T_con]):
+class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]):
     def __call__(self, instance: Any, value: _T_con) -> None:
         ...
 
 
-class _DictSetterProtocol(_SetterProtocol[_T_con]):
+class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]):
     def __call__(self, instance: Any, key: Any, value: _T_con) -> None:
         ...
 
 
-class _CreatorProtocol(Protocol[_T_co]):
+# mypy 0.990 we are no longer allowed to make this Protocol[_T_con]
+class _CreatorProtocol(Protocol):
     ...
 
 
-class _PlainCreatorProtocol(_CreatorProtocol[_T_con]):
+class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]):
     def __call__(self, value: _T_con) -> Any:
         ...
 
 
-class _KeyCreatorProtocol(_CreatorProtocol[_T_con]):
+class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]):
     def __call__(self, key: Any, value: Optional[_T_con]) -> Any:
         ...
 
@@ -188,7 +190,7 @@ class _GetSetFactoryProtocol(Protocol):
         self,
         collection_class: Optional[Type[Any]],
         assoc_instance: AssociationProxyInstance[Any],
-    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]:
         ...
 
 
@@ -196,7 +198,7 @@ class _ProxyFactoryProtocol(Protocol):
     def __call__(
         self,
         lazy_collection: _LazyCollectionProtocol[Any],
-        creator: _CreatorProtocol[Any],
+        creator: _CreatorProtocol,
         value_attr: str,
         parent: AssociationProxyInstance[Any],
     ) -> Any:
@@ -214,7 +216,7 @@ class _AssociationProxyProtocol(Protocol[_T]):
     """describes the interface of :class:`.AssociationProxy`
     without including descriptor methods in the interface."""
 
-    creator: Optional[_CreatorProtocol[Any]]
+    creator: Optional[_CreatorProtocol]
     key: str
     target_collection: str
     value_attr: str
@@ -233,7 +235,7 @@ class _AssociationProxyProtocol(Protocol[_T]):
 
     def _default_getset(
         self, collection_class: Any
-    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]:
         ...
 
 
@@ -256,7 +258,7 @@ class AssociationProxy(
         self,
         target_collection: str,
         attr: str,
-        creator: Optional[_CreatorProtocol[Any]] = None,
+        creator: Optional[_CreatorProtocol] = None,
         getset_factory: Optional[_GetSetFactoryProtocol] = None,
         proxy_factory: Optional[_ProxyFactoryProtocol] = None,
         proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None,
@@ -459,7 +461,7 @@ class AssociationProxy(
 
     def _default_getset(
         self, collection_class: Any
-    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]:
         attr = self.value_attr
         _getter = operator.attrgetter(attr)
 
@@ -760,7 +762,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
 
     def _default_getset(
         self, collection_class: Any
-    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol[Any]]:
+    ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]:
         attr = self.value_attr
         _getter = operator.attrgetter(attr)
 
@@ -864,7 +866,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
         creator = (
             self.parent.creator
             if self.parent.creator is not None
-            else cast("_CreatorProtocol[_T]", self.target_class)
+            else cast("_CreatorProtocol", self.target_class)
         )
         collection_class = util.duck_type_collection(lazy_collection())
 
@@ -945,7 +947,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
         creator = (
             self.parent.creator
             and self.parent.creator
-            or cast(_CreatorProtocol[Any], self.target_class)
+            or cast(_CreatorProtocol, self.target_class)
         )
 
         if self.parent.getset_factory:
@@ -1266,7 +1268,7 @@ class _AssociationCollection(Generic[_IT]):
     getter: _GetterProtocol[_IT]
     """A function.  Given an associated object, return the 'value'."""
 
-    creator: _CreatorProtocol[_IT]
+    creator: _CreatorProtocol
     """
     A function that creates new target entities.  Given one parameter:
     value.  This assertion is assumed::
@@ -1276,7 +1278,7 @@ class _AssociationCollection(Generic[_IT]):
     """
 
     parent: AssociationProxyInstance[_IT]
-    setter: _SetterProtocol[_IT]
+    setter: _SetterProtocol
     """A function.  Given an associated object and a value, store that
         value on the object.
     """
@@ -1288,9 +1290,9 @@ class _AssociationCollection(Generic[_IT]):
     def __init__(
         self,
         lazy_collection: _LazyCollectionProtocol[_IT],
-        creator: _CreatorProtocol[_IT],
+        creator: _CreatorProtocol,
         getter: _GetterProtocol[_IT],
-        setter: _SetterProtocol[_IT],
+        setter: _SetterProtocol,
         parent: AssociationProxyInstance[_IT],
     ):
         """Constructs an _AssociationCollection.
index 651533db6a10f7cf180ef20d1143fe46ae23353e..cfe48800396eaab8a3f9e48fa5903c41e3db6dea 100644 (file)
@@ -885,10 +885,12 @@ class BulkUDCompileState(ORMDMLState):
         if crit:
             eval_condition = evaluator_compiler.process(*crit)
         else:
-
-            def eval_condition(obj):
+            # workaround for mypy https://github.com/python/mypy/issues/14027
+            def _eval_condition(obj):
                 return True
 
+            eval_condition = _eval_condition
+
         return eval_condition
 
     @classmethod
index 94c5181de38d6d683b2c3059ae26d25138ccc7d0..01766ad850c2b0a74cf1787f66a750d7149408a5 100644 (file)
@@ -1,4 +1,4 @@
-# ext/declarative/api.py
+# orm/declarative/api.py
 # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
@@ -1274,7 +1274,7 @@ class registry:
         if isinstance(cls, type):
             class_dict["__doc__"] = cls.__doc__
 
-        if self.constructor:
+        if self.constructor is not None:
             class_dict["__init__"] = self.constructor
 
         class_dict["__abstract__"] = True
index 18ff66e5de773578f731e9e119c331514aa7531d..47c39791c5c2011eedacfdf358c0c558200a08dd 100644 (file)
@@ -727,7 +727,7 @@ class _ConnectionRecord(ConnectionPoolEntry):
             lambda ref: _finalize_fairy(
                 None, rec, pool, ref, echo, transaction_was_reset=False
             )
-            if _finalize_fairy
+            if _finalize_fairy is not None
             else None,
         )
         _strong_ref_connection_records[ref] = rec
index 704c0d19cf982f1eb4c54bab8deaf2aa4bf0a71c..3e62cb3505d0fbfc1381d040d5c18b5d626b1706 100644 (file)
@@ -1694,7 +1694,7 @@ class SQLCompiler(Compiled):
             # at all if the key were present in the parameters
             if autoinc_key in self.binds:
 
-                def autoinc_getter(lastrowid, parameters):
+                def _autoinc_getter(lastrowid, parameters):
                     param_value = parameters.get(autoinc_key, lastrowid)
                     if param_value is not None:
                         # they supplied non-None parameter, use that.
@@ -1706,6 +1706,9 @@ class SQLCompiler(Compiled):
                         # use lastrowid
                         return lastrowid
 
+                # work around mypy https://github.com/python/mypy/issues/14027
+                autoinc_getter = _autoinc_getter
+
         else:
             lastrowid_processor = None
 
@@ -1727,7 +1730,7 @@ class SQLCompiler(Compiled):
                 return row_fn(
                     (
                         autoinc_getter(lastrowid, parameters)
-                        if autoinc_getter
+                        if autoinc_getter is not None
                         else lastrowid
                     )
                     if col is autoinc_col
index 94e635740e7f391ed04405abf5a4ef54c891df29..135407321469752596c032d8a943d679a6f9c4ff 100644 (file)
@@ -565,7 +565,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
                 assert right_attrname is not None
 
                 dispatch = self.dispatch(left_visit_sym)
-                assert dispatch, (
+                assert dispatch is not None, (
                     f"{self.__class__} has no dispatch for "
                     f"'{self._dispatch_lookup[left_visit_sym]}'"
                 )
index b550f8f28674c546ee6e2400fa685d6f3aa9aa92..73710784431e441a1c12aa13d2f0f62011494516 100644 (file)
@@ -553,7 +553,7 @@ class HasTraversalDispatch:
         names = []
         for attrname, visit_sym in internal_dispatch:
             meth = self.dispatch(visit_sym)
-            if meth:
+            if meth is not None:
                 visit_name = _dispatch_lookup[visit_sym]
                 names.append((attrname, visit_name))
 
index d4dac7249c5b6690ec61d19bc69a7fba7ea65b7b..d8d39f56c6200ec633175b81cc6de0283d0bb029 100644 (file)
@@ -1448,7 +1448,7 @@ def duck_type_collection(
         else:
             return specimen.__emulates__  # type: ignore
 
-    isa = isinstance(specimen, type) and issubclass or isinstance
+    isa = issubclass if isinstance(specimen, type) else isinstance
     if isa(specimen, list):
         return list
     elif isa(specimen, set):
index e4674a44cb03ca4aee1aa722ae390f99b952b229..9eb761eff0629016da1fe7608194d5d4c870da9b 100644 (file)
@@ -1,3 +1,9 @@
+# util/typing.py
+# Copyright (C) 2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: allow-untyped-defs, allow-untyped-calls
 
 from __future__ import annotations
@@ -17,6 +23,7 @@ from typing import Optional
 from typing import overload
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -368,14 +375,16 @@ class DescriptorReference(Generic[_DESC]):
 
     """
 
-    def __get__(self, instance: object, owner: Any) -> _DESC:
-        ...
+    if TYPE_CHECKING:
 
-    def __set__(self, instance: Any, value: _DESC) -> None:
-        ...
+        def __get__(self, instance: object, owner: Any) -> _DESC:
+            ...
 
-    def __delete__(self, instance: Any) -> None:
-        ...
+        def __set__(self, instance: Any, value: _DESC) -> None:
+            ...
+
+        def __delete__(self, instance: Any) -> None:
+            ...
 
 
 _DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
@@ -389,14 +398,16 @@ class RODescriptorReference(Generic[_DESC_co]):
 
     """
 
-    def __get__(self, instance: object, owner: Any) -> _DESC_co:
-        ...
+    if TYPE_CHECKING:
 
-    def __set__(self, instance: Any, value: Any) -> NoReturn:
-        ...
+        def __get__(self, instance: object, owner: Any) -> _DESC_co:
+            ...
 
-    def __delete__(self, instance: Any) -> NoReturn:
-        ...
+        def __set__(self, instance: Any, value: Any) -> NoReturn:
+            ...
+
+        def __delete__(self, instance: Any) -> NoReturn:
+            ...
 
 
 _FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
@@ -411,14 +422,16 @@ class CallableReference(Generic[_FN]):
 
     """
 
-    def __get__(self, instance: object, owner: Any) -> _FN:
-        ...
+    if TYPE_CHECKING:
 
-    def __set__(self, instance: Any, value: _FN) -> None:
-        ...
+        def __get__(self, instance: object, owner: Any) -> _FN:
+            ...
 
-    def __delete__(self, instance: Any) -> None:
-        ...
+        def __set__(self, instance: Any, value: _FN) -> None:
+            ...
+
+        def __delete__(self, instance: Any) -> None:
+            ...
 
 
 # $def ro_descriptor_reference(fn: Callable[])
index dce75b4f346932e86e419ac95fe43c4849b04418..5d3388ca6ea5fae469e40231b923a0e6d97b4452 100644 (file)
@@ -117,7 +117,15 @@ class MypyPluginTest(fixtures.TestBase):
                 ),
             ]
 
-            args.append(path)
+            # mypy as of 0.990 is more aggressively blocking messaging
+            # for paths that are in sys.path, and as pytest puts currdir,
+            # test/ etc in sys.path, just copy the source file to the
+            # tempdir we are working in so that we don't have to try to
+            # manipulate sys.path and/or guess what mypy is doing
+            filename = os.path.basename(path)
+            test_program = os.path.join(cachedir, filename)
+            shutil.copyfile(path, test_program)
+            args.append(test_program)
 
             # I set this locally but for the suite here needs to be
             # disabled
@@ -281,7 +289,9 @@ class MypyPluginTest(fixtures.TestBase):
         not_located = []
 
         if expected_messages:
-            eq_(result[2], 1, msg=result)
+            # mypy 0.990 changed how return codes work, so don't assume a
+            # 1 or a 0 return code here, could be either depending on if
+            # errors were generated or not
 
             output = []