]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve compiled extension detection
authorFederico Caselli <cfederico87@gmail.com>
Sun, 4 Sep 2022 19:44:46 +0000 (21:44 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 5 Sep 2022 20:32:14 +0000 (22:32 +0200)
Ensure that all cython extension are imported by the compied detection logic.
This is required since cython extensions moduels are marked as optional
in the install, so it's possible that only some of them are compiled.
The extensions are enabled only if all of them are correctly compiled

Change-Id: I355cbac06f5c7a47d35661f42ebab3b0156c1965

lib/sqlalchemy/util/_has_cy.py
setup.py
test/base/test_utils.py

index cf68c1933b9037cf0fd26e058198ee7e79b98a91..072c78d2cfdec9744bd8c19aac90662ef1226099 100644 (file)
@@ -1,14 +1,32 @@
-# mypy: allow-untyped-defs, allow-untyped-calls
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
 
 import os
 import typing
 
+
+def _import_cy_extensions():
+    # all cython extension extension modules are treated as optional by the
+    # setup, so to ensure that all are compiled, all should be imported here
+    from ..cyextension import collections
+    from ..cyextension import immutabledict
+    from ..cyextension import processors
+    from ..cyextension import resultproxy
+    from ..cyextension import util
+
+    return (collections, immutabledict, processors, resultproxy, util)
+
+
 if not typing.TYPE_CHECKING:
     if os.environ.get("DISABLE_SQLALCHEMY_CEXT_RUNTIME"):
         HAS_CYEXTENSION = False
     else:
         try:
-            from ..cyextension import util  # noqa
+            _import_cy_extensions()
         except ImportError:
             HAS_CYEXTENSION = False
         else:
index ae71acd2c528006093b43616100ab4559244218b..e301cdb29bd944259fc0dc627b7dac855ecd0251 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@ if HAS_CYTHON and IS_CPYTHON and not DISABLE_EXTENSION:
     assert _cy_Extension is not None
     assert _cy_build_ext is not None
 
+    # when adding a cython module, also update the imports in _has_cy
     cython_files = [
         "collections.pyx",
         "immutabledict.pyx",
index 98451cc4f1716b1abe3ba683e6639e2a97cb8efc..27945f23682ee9a10cb8d27a09e7c328bd04759b 100644 (file)
@@ -2,6 +2,7 @@
 
 import copy
 import inspect
+from pathlib import Path
 import pickle
 import sys
 
@@ -33,6 +34,8 @@ from sqlalchemy.util import langhelpers
 from sqlalchemy.util import preloaded
 from sqlalchemy.util import WeakSequence
 from sqlalchemy.util._collections import merge_lists_w_ordering
+from sqlalchemy.util._has_cy import _import_cy_extensions
+from sqlalchemy.util._has_cy import HAS_CYEXTENSION
 
 
 class WeakSequenceTest(fixtures.TestBase):
@@ -3346,3 +3349,18 @@ class MethodOveriddenTest(fixtures.TestBase):
                 pass
 
         is_true(util.method_is_overridden(HoHo(), Bat.bar))
+
+
+class CyExtensionTest(fixtures.TestBase):
+    @testing.only_if(lambda: HAS_CYEXTENSION, "No Cython")
+    def test_all_cyext_imported(self):
+        ext = _import_cy_extensions()
+        lib_folder = (Path(__file__).parent / ".." / ".." / "lib").resolve()
+        sa_folder = lib_folder / "sqlalchemy"
+        cython_files = [f.resolve() for f in sa_folder.glob("**/*.pyx")]
+        eq_(len(ext), len(cython_files))
+        names = {
+            ".".join(f.relative_to(lib_folder).parts).replace(".pyx", "")
+            for f in cython_files
+        }
+        eq_({m.__name__ for m in ext}, set(names))