From: Federico Caselli Date: Sun, 4 Sep 2022 19:44:46 +0000 (+0200) Subject: Improve compiled extension detection X-Git-Tag: rel_2_0_0b1~75 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=741534d840daeeba73aad1703b5bdeb3a0b86db9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve compiled extension detection Ensure that all cython extension are imported by the compied detection logic. This is required since cython extensions moduels are marked as optional in the install, so it's possible that only some of them are compiled. The extensions are enabled only if all of them are correctly compiled Change-Id: I355cbac06f5c7a47d35661f42ebab3b0156c1965 --- diff --git a/lib/sqlalchemy/util/_has_cy.py b/lib/sqlalchemy/util/_has_cy.py index cf68c1933b..072c78d2cf 100644 --- a/lib/sqlalchemy/util/_has_cy.py +++ b/lib/sqlalchemy/util/_has_cy.py @@ -1,14 +1,32 @@ -# mypy: allow-untyped-defs, allow-untyped-calls +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors import os import typing + +def _import_cy_extensions(): + # all cython extension extension modules are treated as optional by the + # setup, so to ensure that all are compiled, all should be imported here + from ..cyextension import collections + from ..cyextension import immutabledict + from ..cyextension import processors + from ..cyextension import resultproxy + from ..cyextension import util + + return (collections, immutabledict, processors, resultproxy, util) + + if not typing.TYPE_CHECKING: if os.environ.get("DISABLE_SQLALCHEMY_CEXT_RUNTIME"): HAS_CYEXTENSION = False else: try: - from ..cyextension import util # noqa + _import_cy_extensions() except ImportError: HAS_CYEXTENSION = False else: diff --git a/setup.py b/setup.py index ae71acd2c5..e301cdb29b 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ if HAS_CYTHON and IS_CPYTHON and not DISABLE_EXTENSION: assert _cy_Extension is not None assert _cy_build_ext is not None + # when adding a cython module, also update the imports in _has_cy cython_files = [ "collections.pyx", "immutabledict.pyx", diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 98451cc4f1..27945f2368 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -2,6 +2,7 @@ import copy import inspect +from pathlib import Path import pickle import sys @@ -33,6 +34,8 @@ from sqlalchemy.util import langhelpers from sqlalchemy.util import preloaded from sqlalchemy.util import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering +from sqlalchemy.util._has_cy import _import_cy_extensions +from sqlalchemy.util._has_cy import HAS_CYEXTENSION class WeakSequenceTest(fixtures.TestBase): @@ -3346,3 +3349,18 @@ class MethodOveriddenTest(fixtures.TestBase): pass is_true(util.method_is_overridden(HoHo(), Bat.bar)) + + +class CyExtensionTest(fixtures.TestBase): + @testing.only_if(lambda: HAS_CYEXTENSION, "No Cython") + def test_all_cyext_imported(self): + ext = _import_cy_extensions() + lib_folder = (Path(__file__).parent / ".." / ".." / "lib").resolve() + sa_folder = lib_folder / "sqlalchemy" + cython_files = [f.resolve() for f in sa_folder.glob("**/*.pyx")] + eq_(len(ext), len(cython_files)) + names = { + ".".join(f.relative_to(lib_folder).parts).replace(".pyx", "") + for f in cython_files + } + eq_({m.__name__ for m in ext}, set(names))