From 0be89aaa38d06a9beced7f1bfe2987f4b6afebb8 Mon Sep 17 00:00:00 2001
From: Federico Caselli
Date: Sat, 23 Sep 2023 23:39:42 +0200
Subject: [PATCH] Switch to cython pure python mode

Replaces the pyx files with py files that can be both compiled by
cython and imported as-is by python. This avoids the need to duplicate
the code in order to provide a python-only fallback. The cython files
are also reorganized to live in the modules that use them, instead of
all residing in the cyextension package, which has been removed.

Performance is essentially equal between main and this change. A
detailed comparison is available at
https://docs.google.com/spreadsheets/d/1jkmGpnCyEcPyy6aRK9alElGjxlNHu44Wxjr4VrD99so/edit?usp=sharing

Change-Id: Iaed232ea5dfb41534cc9f58f6ea2f912a93263af
---
 .gitignore | 6 +
 MANIFEST.in | 8 +-
 doc/build/conf.py | 2 -
 lib/sqlalchemy/cyextension/.gitignore | 5 -
 lib/sqlalchemy/cyextension/__init__.py | 6 -
 lib/sqlalchemy/cyextension/collections.pyx | 409 -------
 lib/sqlalchemy/cyextension/immutabledict.pxd | 8 -
 lib/sqlalchemy/cyextension/immutabledict.pyx | 133 -----
 lib/sqlalchemy/cyextension/processors.pyx | 68 ---
 lib/sqlalchemy/cyextension/resultproxy.pyx | 102 ----
 lib/sqlalchemy/cyextension/util.pyx | 91 ---
 lib/sqlalchemy/engine/_processors_cy.py | 92 +++
 lib/sqlalchemy/engine/_py_processors.py | 136 -----
 lib/sqlalchemy/engine/_py_row.py | 129 -----
 lib/sqlalchemy/engine/_py_util.py | 74 ---
 lib/sqlalchemy/engine/_row_cy.py | 162 ++++++
 lib/sqlalchemy/engine/_util_cy.py | 129 +++++
 lib/sqlalchemy/engine/base.py | 38 +-
 lib/sqlalchemy/engine/processors.py | 101 ++--
 lib/sqlalchemy/engine/result.py | 13 +-
 lib/sqlalchemy/engine/row.py | 9 +-
 lib/sqlalchemy/engine/util.py | 15 +-
 lib/sqlalchemy/orm/collections.py | 5 +-
 lib/sqlalchemy/sql/_py_util.py | 75 ---
 lib/sqlalchemy/sql/_util_cy.py | 108 ++++
 lib/sqlalchemy/sql/visitors.py | 14 +-
 lib/sqlalchemy/testing/plugin/pytestplugin.py | 6 +-
 lib/sqlalchemy/util/_collections.py | 40 +-
 lib/sqlalchemy/util/_collections_cy.py | 528 +++++++++++++++++
 lib/sqlalchemy/util/_has_cy.py | 40 --
 lib/sqlalchemy/util/_has_cython.py | 44 ++
 lib/sqlalchemy/util/_immutabledict_cy.py | 208 +++
 lib/sqlalchemy/util/_py_collections.py | 541 ------------------
 lib/sqlalchemy/util/cython.py | 61 ++
 lib/sqlalchemy/util/langhelpers.py | 34 +-
 pyproject.toml | 7 +-
 setup.py | 28 +-
 test/aaa_profiling/test_memusage.py | 2 +-
 test/base/test_result.py | 37 +-
 test/base/test_utils.py | 37 +-
 test/engine/test_processors.py | 74 ++-
 test/perf/compiled_extensions.py | 439 ++++++++------
 test/profiles.txt | 18 +-
 tools/cython_imports.py | 73 +++
 tox.ini | 3 +-
 45 files changed, 1932 insertions(+), 2226 deletions(-)
delete mode 100644 lib/sqlalchemy/cyextension/.gitignore delete mode 100644 lib/sqlalchemy/cyextension/__init__.py delete mode 100644 lib/sqlalchemy/cyextension/collections.pyx delete mode 100644 lib/sqlalchemy/cyextension/immutabledict.pxd delete mode 100644 lib/sqlalchemy/cyextension/immutabledict.pyx delete mode 100644 lib/sqlalchemy/cyextension/processors.pyx delete mode 100644 lib/sqlalchemy/cyextension/resultproxy.pyx delete mode 100644 lib/sqlalchemy/cyextension/util.pyx create mode 100644 lib/sqlalchemy/engine/_processors_cy.py delete mode 100644 lib/sqlalchemy/engine/_py_processors.py delete mode 100644 lib/sqlalchemy/engine/_py_row.py delete mode 100644 lib/sqlalchemy/engine/_py_util.py create mode 100644 lib/sqlalchemy/engine/_row_cy.py create mode 100644 lib/sqlalchemy/engine/_util_cy.py delete mode 100644 lib/sqlalchemy/sql/_py_util.py
create mode 100644 lib/sqlalchemy/sql/_util_cy.py create mode 100644 lib/sqlalchemy/util/_collections_cy.py delete mode 100644 lib/sqlalchemy/util/_has_cy.py create mode 100644 lib/sqlalchemy/util/_has_cython.py create mode 100644 lib/sqlalchemy/util/_immutabledict_cy.py delete mode 100644 lib/sqlalchemy/util/_py_collections.py create mode 100644 lib/sqlalchemy/util/cython.py create mode 100644 tools/cython_imports.py diff --git a/.gitignore b/.gitignore index 13b40c819a..f2544502f3 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,9 @@ test/test_schema.db /db_idents.txt .DS_Store .vs + +# cython complied files +/lib/**/*.c +/lib/**/*.cpp +# cython annotated output +/lib/**/*.html diff --git a/MANIFEST.in b/MANIFEST.in index 7a272fe6b4..22a39e89c7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,12 +8,12 @@ recursive-include tools *.py # for some reason in some environments stale Cython .c files # are being pulled in, these should never be in a dist -exclude lib/sqlalchemy/cyextension/*.c -exclude lib/sqlalchemy/cyextension/*.so +exclude lib/sqlalchemy/**/*.c +exclude lib/sqlalchemy/**/*.so -# include the pyx and pxd extensions, which otherwise +# include the pxd extensions, which otherwise # don't come in if --with-cextensions isn't specified. -recursive-include lib *.pyx *.pxd *.txt *.typed +recursive-include lib *.pxd *.txt *.typed include README* AUTHORS LICENSE CHANGES* tox.ini prune doc/build/output diff --git a/doc/build/conf.py b/doc/build/conf.py index bda3ff1d3c..5e89280be8 100644 --- a/doc/build/conf.py +++ b/doc/build/conf.py @@ -25,8 +25,6 @@ sys.path.insert(0, os.path.abspath("../..")) # examples # sys.path.insert(0, os.path.abspath(".")) -os.environ["DISABLE_SQLALCHEMY_CEXT_RUNTIME"] = "true" - # -- General configuration -------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. 
diff --git a/lib/sqlalchemy/cyextension/.gitignore b/lib/sqlalchemy/cyextension/.gitignore deleted file mode 100644 index dfc107eafc..0000000000 --- a/lib/sqlalchemy/cyextension/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -# cython complied files -*.c -*.o -# cython annotated output -*.html \ No newline at end of file diff --git a/lib/sqlalchemy/cyextension/__init__.py b/lib/sqlalchemy/cyextension/__init__.py deleted file mode 100644 index 88a4d90396..0000000000 --- a/lib/sqlalchemy/cyextension/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# cyextension/__init__.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx deleted file mode 100644 index 86d24852b3..0000000000 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ /dev/null @@ -1,409 +0,0 @@ -# cyextension/collections.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -cimport cython -from cpython.long cimport PyLong_FromLongLong -from cpython.set cimport PySet_Add - -from collections.abc import Collection -from itertools import filterfalse - -cdef bint add_not_present(set seen, object item, hashfunc): - hash_value = hashfunc(item) - if hash_value not in seen: - PySet_Add(seen, hash_value) - return True - else: - return False - -cdef list cunique_list(seq, hashfunc=None): - cdef set seen = set() - if not hashfunc: - return [x for x in seq if x not in seen and not PySet_Add(seen, x)] - else: - return [x for x in seq if add_not_present(seen, x, hashfunc)] - -def unique_list(seq, hashfunc=None): - return cunique_list(seq, hashfunc) - -cdef class OrderedSet(set): - - cdef list _list - - @classmethod - def __class_getitem__(cls, key): - return cls - - def __init__(self, d=None): - set.__init__(self) - if d is not None: - self._list = cunique_list(d) - set.update(self, self._list) - else: - self._list = [] - - cpdef OrderedSet copy(self): - cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) - cp._list = list(self._list) - set.update(cp, cp._list) - return cp - - @cython.final - cdef OrderedSet _from_list(self, list new_list): - cdef OrderedSet new = OrderedSet.__new__(OrderedSet) - new._list = new_list - set.update(new, new_list) - return new - - def add(self, element): - if element not in self: - self._list.append(element) - PySet_Add(self, element) - - def remove(self, element): - # set.remove will raise if element is not in self - set.remove(self, element) - self._list.remove(element) - - def pop(self): - try: - value = self._list.pop() - except IndexError: - raise KeyError("pop from an empty set") from None - set.remove(self, value) - return value - - def insert(self, Py_ssize_t pos, element): - if element not in self: - self._list.insert(pos, element) - PySet_Add(self, element) - - def discard(self, element): - if element in self: - set.remove(self, element) - self._list.remove(element) - - def clear(self): - set.clear(self) - self._list = [] - - def __getitem__(self, key): - return self._list[key] - - def __iter__(self): - return iter(self._list) - - def __add__(self, other): - return self.union(other) - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._list) - - __str__ = __repr__ - - def update(self, 
*iterables): - for iterable in iterables: - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - - def __ior__(self, iterable): - self.update(iterable) - return self - - def union(self, *other): - result = self.copy() - result.update(*other) - return result - - def __or__(self, other): - return self.union(other) - - def intersection(self, *other): - cdef set other_set = set.intersection(self, *other) - return self._from_list([a for a in self._list if a in other_set]) - - def __and__(self, other): - return self.intersection(other) - - def symmetric_difference(self, other): - cdef set other_set - if isinstance(other, set): - other_set = other - collection = other_set - elif isinstance(other, Collection): - collection = other - other_set = set(other) - else: - collection = list(other) - other_set = set(collection) - result = self._from_list([a for a in self._list if a not in other_set]) - result.update(a for a in collection if a not in self) - return result - - def __xor__(self, other): - return self.symmetric_difference(other) - - def difference(self, *other): - cdef set other_set = set.difference(self, *other) - return self._from_list([a for a in self._list if a in other_set]) - - def __sub__(self, other): - return self.difference(other) - - def intersection_update(self, *other): - set.intersection_update(self, *other) - self._list = [a for a in self._list if a in self] - - def __iand__(self, other): - self.intersection_update(other) - return self - - cpdef symmetric_difference_update(self, other): - collection = other if isinstance(other, Collection) else list(other) - set.symmetric_difference_update(self, collection) - self._list = [a for a in self._list if a in self] - self._list += [a for a in collection if a in self] - - def __ixor__(self, other): - self.symmetric_difference_update(other) - return self - - def difference_update(self, *other): - set.difference_update(self, *other) - self._list = [a for a in self._list if a in self] - - def __isub__(self, other): - self.difference_update(other) - return self - -cdef object cy_id(object item): - return PyLong_FromLongLong( (item)) - -# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped -# instead of the __rmeth__, so they need to check that also self is of the -# correct type. This is fixed in cython 3.x. See: -# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods -cdef class IdentitySet: - """A set that considers only object id() for uniqueness. - - This strategy has edge cases for builtin types- it's possible to have - two 'foo' strings in one of these sets, for example. Use sparingly. 
- - """ - - cdef dict _members - - def __init__(self, iterable=None): - self._members = {} - if iterable: - self.update(iterable) - - def add(self, value): - self._members[cy_id(value)] = value - - def __contains__(self, value): - return cy_id(value) in self._members - - cpdef remove(self, value): - del self._members[cy_id(value)] - - def discard(self, value): - try: - self.remove(value) - except KeyError: - pass - - def pop(self): - cdef tuple pair - try: - pair = self._members.popitem() - return pair[1] - except KeyError: - raise KeyError("pop from an empty set") - - def clear(self): - self._members.clear() - - def __eq__(self, other): - cdef IdentitySet other_ - if isinstance(other, IdentitySet): - other_ = other - return self._members == other_._members - else: - return False - - def __ne__(self, other): - cdef IdentitySet other_ - if isinstance(other, IdentitySet): - other_ = other - return self._members != other_._members - else: - return True - - cpdef issubset(self, iterable): - cdef IdentitySet other - if isinstance(iterable, self.__class__): - other = iterable - else: - other = self.__class__(iterable) - - if len(self) > len(other): - return False - for m in filterfalse(other._members.__contains__, self._members): - return False - return True - - def __le__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.issubset(other) - - def __lt__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return len(self) < len(other) and self.issubset(other) - - cpdef issuperset(self, iterable): - cdef IdentitySet other - if isinstance(iterable, self.__class__): - other = iterable - else: - other = self.__class__(iterable) - - if len(self) < len(other): - return False - for m in filterfalse(self._members.__contains__, other._members): - return False - return True - - def __ge__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.issuperset(other) - - def __gt__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return len(self) > len(other) and self.issuperset(other) - - cpdef IdentitySet union(self, iterable): - cdef IdentitySet result = self.__class__() - result._members.update(self._members) - result.update(iterable) - return result - - def __or__(self, other): - if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): - return NotImplemented - return self.union(other) - - cpdef update(self, iterable): - for obj in iterable: - self._members[cy_id(obj)] = obj - - def __ior__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.update(other) - return self - - cpdef IdentitySet difference(self, iterable): - cdef IdentitySet result = self.__new__(self.__class__) - if isinstance(iterable, self.__class__): - other = (iterable)._members - else: - other = {cy_id(obj) for obj in iterable} - result._members = {k:v for k, v in self._members.items() if k not in other} - return result - - def __sub__(self, other): - if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): - return NotImplemented - return self.difference(other) - - cpdef difference_update(self, iterable): - cdef IdentitySet other = self.difference(iterable) - self._members = other._members - - def __isub__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.difference_update(other) - return self - - cpdef IdentitySet intersection(self, iterable): - cdef IdentitySet result = 
self.__new__(self.__class__) - if isinstance(iterable, self.__class__): - other = (iterable)._members - else: - other = {cy_id(obj) for obj in iterable} - result._members = {k: v for k, v in self._members.items() if k in other} - return result - - def __and__(self, other): - if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): - return NotImplemented - return self.intersection(other) - - cpdef intersection_update(self, iterable): - cdef IdentitySet other = self.intersection(iterable) - self._members = other._members - - def __iand__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.intersection_update(other) - return self - - cpdef IdentitySet symmetric_difference(self, iterable): - cdef IdentitySet result = self.__new__(self.__class__) - cdef dict other - if isinstance(iterable, self.__class__): - other = (iterable)._members - else: - other = {cy_id(obj): obj for obj in iterable} - result._members = {k: v for k, v in self._members.items() if k not in other} - result._members.update( - [(k, v) for k, v in other.items() if k not in self._members] - ) - return result - - def __xor__(self, other): - if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): - return NotImplemented - return self.symmetric_difference(other) - - cpdef symmetric_difference_update(self, iterable): - cdef IdentitySet other = self.symmetric_difference(iterable) - self._members = other._members - - def __ixor__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.symmetric_difference(other) - return self - - cpdef IdentitySet copy(self): - cdef IdentitySet cp = self.__new__(self.__class__) - cp._members = self._members.copy() - return cp - - def __copy__(self): - return self.copy() - - def __len__(self): - return len(self._members) - - def __iter__(self): - return iter(self._members.values()) - - def __hash__(self): - raise TypeError("set objects are unhashable") - - def __repr__(self): - return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/lib/sqlalchemy/cyextension/immutabledict.pxd b/lib/sqlalchemy/cyextension/immutabledict.pxd deleted file mode 100644 index 76f2289316..0000000000 --- a/lib/sqlalchemy/cyextension/immutabledict.pxd +++ /dev/null @@ -1,8 +0,0 @@ -# cyextension/immutabledict.pxd -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -cdef class immutabledict(dict): - pass diff --git a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx deleted file mode 100644 index b37eccc4c3..0000000000 --- a/lib/sqlalchemy/cyextension/immutabledict.pyx +++ /dev/null @@ -1,133 +0,0 @@ -# cyextension/immutabledict.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size - - -def _readonly_fn(obj): - raise TypeError( - "%s object is immutable and/or readonly" % obj.__class__.__name__) - - -def _immutable_fn(obj): - raise TypeError( - "%s object is immutable" % obj.__class__.__name__) - - -class ReadOnlyContainer: - - __slots__ = () - - def _readonly(self, *a,**kw): - _readonly_fn(self) - - __delitem__ = __setitem__ = __setattr__ = _readonly - - -class ImmutableDictBase(dict): - def 
_immutable(self, *a,**kw): - _immutable_fn(self) - - @classmethod - def __class_getitem__(cls, key): - return cls - - __delitem__ = __setitem__ = __setattr__ = _immutable - clear = pop = popitem = setdefault = update = _immutable - - -cdef class immutabledict(dict): - def __repr__(self): - return f"immutabledict({dict.__repr__(self)})" - - @classmethod - def __class_getitem__(cls, key): - return cls - - def union(self, *args, **kw): - cdef dict to_merge = None - cdef immutabledict result - cdef Py_ssize_t args_len = len(args) - if args_len > 1: - raise TypeError( - f'union expected at most 1 argument, got {args_len}' - ) - if args_len == 1: - attribute = args[0] - if isinstance(attribute, dict): - to_merge = attribute - if to_merge is None: - to_merge = dict(*args, **kw) - - if PyDict_Size(to_merge) == 0: - return self - - # new + update is faster than immutabledict(self) - result = immutabledict() - PyDict_Update(result, self) - PyDict_Update(result, to_merge) - return result - - def merge_with(self, *other): - cdef immutabledict result = None - cdef object d - cdef bint update = False - if not other: - return self - for d in other: - if d: - if update == False: - update = True - # new + update is faster than immutabledict(self) - result = immutabledict() - PyDict_Update(result, self) - PyDict_Update( - result, (d if isinstance(d, dict) else dict(d)) - ) - - return self if update == False else result - - def copy(self): - return self - - def __reduce__(self): - return immutabledict, (dict(self), ) - - def __delitem__(self, k): - _immutable_fn(self) - - def __setitem__(self, k, v): - _immutable_fn(self) - - def __setattr__(self, k, v): - _immutable_fn(self) - - def clear(self, *args, **kw): - _immutable_fn(self) - - def pop(self, *args, **kw): - _immutable_fn(self) - - def popitem(self, *args, **kw): - _immutable_fn(self) - - def setdefault(self, *args, **kw): - _immutable_fn(self) - - def update(self, *args, **kw): - _immutable_fn(self) - - # PEP 584 - def __ior__(self, other): - _immutable_fn(self) - - def __or__(self, other): - return immutabledict(dict.__or__(self, other)) - - def __ror__(self, other): - # NOTE: this is used only in cython 3.x; - # version 0.x will call __or__ with args inversed - return immutabledict(dict.__ror__(self, other)) diff --git a/lib/sqlalchemy/cyextension/processors.pyx b/lib/sqlalchemy/cyextension/processors.pyx deleted file mode 100644 index 3d714569fa..0000000000 --- a/lib/sqlalchemy/cyextension/processors.pyx +++ /dev/null @@ -1,68 +0,0 @@ -# cyextension/processors.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -import datetime -from datetime import datetime as datetime_cls -from datetime import time as time_cls -from datetime import date as date_cls -import re - -from cpython.object cimport PyObject_Str -from cpython.unicode cimport PyUnicode_AsASCIIString, PyUnicode_Check, PyUnicode_Decode -from libc.stdio cimport sscanf - - -def int_to_boolean(value): - if value is None: - return None - return True if value else False - -def to_str(value): - return PyObject_Str(value) if value is not None else None - -def to_float(value): - return float(value) if value is not None else None - -cdef inline bytes to_bytes(object value, str type_name): - try: - return PyUnicode_AsASCIIString(value) - except Exception as e: - raise ValueError( - f"Couldn't parse {type_name} string '{value!r}' " - "- value is not 
a string." - ) from e - -def str_to_datetime(value): - if value is not None: - value = datetime_cls.fromisoformat(value) - return value - -def str_to_time(value): - if value is not None: - value = time_cls.fromisoformat(value) - return value - - -def str_to_date(value): - if value is not None: - value = date_cls.fromisoformat(value) - return value - - - -cdef class DecimalResultProcessor: - cdef object type_ - cdef str format_ - - def __cinit__(self, type_, format_): - self.type_ = type_ - self.format_ = format_ - - def process(self, object value): - if value is None: - return None - else: - return self.type_(self.format_ % value) diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx deleted file mode 100644 index b6e357a1f3..0000000000 --- a/lib/sqlalchemy/cyextension/resultproxy.pyx +++ /dev/null @@ -1,102 +0,0 @@ -# cyextension/resultproxy.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -import operator - -cdef class BaseRow: - cdef readonly object _parent - cdef readonly dict _key_to_index - cdef readonly tuple _data - - def __init__(self, object parent, object processors, dict key_to_index, object data): - """Row objects are constructed by CursorResult objects.""" - - self._parent = parent - - self._key_to_index = key_to_index - - if processors: - self._data = _apply_processors(processors, data) - else: - self._data = tuple(data) - - def __reduce__(self): - return ( - rowproxy_reconstructor, - (self.__class__, self.__getstate__()), - ) - - def __getstate__(self): - return {"_parent": self._parent, "_data": self._data} - - def __setstate__(self, dict state): - parent = state["_parent"] - self._parent = parent - self._data = state["_data"] - self._key_to_index = parent._key_to_index - - def _values_impl(self): - return list(self) - - def __iter__(self): - return iter(self._data) - - def __len__(self): - return len(self._data) - - def __hash__(self): - return hash(self._data) - - def __getitem__(self, index): - return self._data[index] - - def _get_by_key_impl_mapping(self, key): - return self._get_by_key_impl(key, 0) - - cdef _get_by_key_impl(self, object key, int attr_err): - index = self._key_to_index.get(key) - if index is not None: - return self._data[index] - self._parent._key_not_found(key, attr_err != 0) - - def __getattr__(self, name): - return self._get_by_key_impl(name, 1) - - def _to_tuple_instance(self): - return self._data - - -cdef tuple _apply_processors(proc, data): - res = [] - for i in range(len(proc)): - p = proc[i] - if p is None: - res.append(data[i]) - else: - res.append(p(data[i])) - return tuple(res) - - -def rowproxy_reconstructor(cls, state): - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj - - -cdef int is_contiguous(tuple indexes): - cdef int i - for i in range(1, len(indexes)): - if indexes[i-1] != indexes[i] -1: - return 0 - return 1 - - -def tuplegetter(*indexes): - if len(indexes) == 1 or is_contiguous(indexes) != 0: - # slice form is faster but returns a list if input is list - return operator.itemgetter(slice(indexes[0], indexes[-1] + 1)) - else: - return operator.itemgetter(*indexes) diff --git a/lib/sqlalchemy/cyextension/util.pyx b/lib/sqlalchemy/cyextension/util.pyx deleted file mode 100644 index cb17acd69c..0000000000 --- a/lib/sqlalchemy/cyextension/util.pyx +++ /dev/null @@ -1,91 +0,0 @@ -# cyextension/util.pyx -# Copyright (C) 
2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -from collections.abc import Mapping - -from sqlalchemy import exc - -cdef tuple _Empty_Tuple = () - -cdef inline bint _mapping_or_tuple(object value): - return isinstance(value, dict) or isinstance(value, tuple) or isinstance(value, Mapping) - -cdef inline bint _check_item(object params) except 0: - cdef object item - cdef bint ret = 1 - if params: - item = params[0] - if not _mapping_or_tuple(item): - ret = 0 - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) - return ret - -def _distill_params_20(object params): - if params is None: - return _Empty_Tuple - elif isinstance(params, list) or isinstance(params, tuple): - _check_item(params) - return params - elif isinstance(params, dict) or isinstance(params, Mapping): - return [params] - else: - raise exc.ArgumentError("mapping or list expected for parameters") - - -def _distill_raw_params(object params): - if params is None: - return _Empty_Tuple - elif isinstance(params, list): - _check_item(params) - return params - elif _mapping_or_tuple(params): - return [params] - else: - raise exc.ArgumentError("mapping or sequence expected for parameters") - -cdef class prefix_anon_map(dict): - def __missing__(self, str key): - cdef str derived - cdef int anonymous_counter - cdef dict self_dict = self - - derived = key.split(" ", 1)[1] - - anonymous_counter = self_dict.get(derived, 1) - self_dict[derived] = anonymous_counter + 1 - value = f"{derived}_{anonymous_counter}" - self_dict[key] = value - return value - - -cdef class cache_anon_map(dict): - cdef int _index - - def __init__(self): - self._index = 0 - - def get_anon(self, obj): - cdef long long idself - cdef str id_ - cdef dict self_dict = self - - idself = id(obj) - if idself in self_dict: - return self_dict[idself], True - else: - id_ = self.__missing__(idself) - return id_, False - - def __missing__(self, key): - cdef str val - cdef dict self_dict = self - - self_dict[key] = val = str(self._index) - self._index += 1 - return val - diff --git a/lib/sqlalchemy/engine/_processors_cy.py b/lib/sqlalchemy/engine/_processors_cy.py new file mode 100644 index 0000000000..7909fd3668 --- /dev/null +++ b/lib/sqlalchemy/engine/_processors_cy.py @@ -0,0 +1,92 @@ +# engine/_processors_cy.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: disable-error-code="misc" +from __future__ import annotations + +from datetime import date as date_cls +from datetime import datetime as datetime_cls +from datetime import time as time_cls +from typing import Any +from typing import Optional + +# START GENERATED CYTHON IMPORT +# This section is automatically generated by the script tools/cython_imports.py +try: + # NOTE: the cython compiler needs this "import cython" in the file, it + # can't be only "from sqlalchemy.util import cython" with the fallback + # in that module + import cython +except ModuleNotFoundError: + from sqlalchemy.util import cython + + +def _is_compiled() -> bool: + """Utility function to indicate if this module is compiled or not.""" + return cython.compiled # type: ignore[no-any-return] + + +# END GENERATED CYTHON IMPORT + + +@cython.annotation_typing(False) +def int_to_boolean(value: Any) -> Optional[bool]: + 
if value is None: + return None + return True if value else False + + +@cython.annotation_typing(False) +def to_str(value: Any) -> Optional[str]: + if value is None: + return None + return str(value) + + +@cython.annotation_typing(False) +def to_float(value: Any) -> Optional[float]: + if value is None: + return None + return float(value) + + +@cython.annotation_typing(False) +def str_to_datetime(value: Optional[str]) -> Optional[datetime_cls]: + if value is None: + return None + return datetime_cls.fromisoformat(value) + + +@cython.annotation_typing(False) +def str_to_time(value: Optional[str]) -> Optional[time_cls]: + if value is None: + return None + return time_cls.fromisoformat(value) + + +@cython.annotation_typing(False) +def str_to_date(value: Optional[str]) -> Optional[date_cls]: + if value is None: + return None + return date_cls.fromisoformat(value) + + +@cython.cclass +class to_decimal_processor_factory: + type_: type + format_: str + + __slots__ = ("type_", "format_") + + def __init__(self, type_: type, scale: int): + self.type_ = type_ + self.format_ = f"%.{scale}f" + + def __call__(self, value: Optional[Any]) -> object: + if value is None: + return None + else: + return self.type_(self.format_ % value) diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py deleted file mode 100644 index 2cc35b501e..0000000000 --- a/lib/sqlalchemy/engine/_py_processors.py +++ /dev/null @@ -1,136 +0,0 @@ -# engine/_py_processors.py -# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors -# -# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php - -"""defines generic type conversion functions, as used in bind and result -processors. - -They all share one common characteristic: None is passed through unchanged. - -""" - -from __future__ import annotations - -import datetime -from datetime import date as date_cls -from datetime import datetime as datetime_cls -from datetime import time as time_cls -from decimal import Decimal -import typing -from typing import Any -from typing import Callable -from typing import Optional -from typing import Type -from typing import TypeVar -from typing import Union - - -_DT = TypeVar( - "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date] -) - - -def str_to_datetime_processor_factory( - regexp: typing.Pattern[str], type_: Callable[..., _DT] -) -> Callable[[Optional[str]], Optional[_DT]]: - rmatch = regexp.match - # Even on python2.6 datetime.strptime is both slower than this code - # and it does not support microseconds. - has_named_groups = bool(regexp.groupindex) - - def process(value: Optional[str]) -> Optional[_DT]: - if value is None: - return None - else: - try: - m = rmatch(value) - except TypeError as err: - raise ValueError( - "Couldn't parse %s string '%r' " - "- value is not a string." 
% (type_.__name__, value) - ) from err - - if m is None: - raise ValueError( - "Couldn't parse %s string: " - "'%s'" % (type_.__name__, value) - ) - if has_named_groups: - groups = m.groupdict(0) - return type_( - **dict( - list( - zip( - iter(groups.keys()), - list(map(int, iter(groups.values()))), - ) - ) - ) - ) - else: - return type_(*list(map(int, m.groups(0)))) - - return process - - -def to_decimal_processor_factory( - target_class: Type[Decimal], scale: int -) -> Callable[[Optional[float]], Optional[Decimal]]: - fstring = "%%.%df" % scale - - def process(value: Optional[float]) -> Optional[Decimal]: - if value is None: - return None - else: - return target_class(fstring % value) - - return process - - -def to_float(value: Optional[Union[int, float]]) -> Optional[float]: - if value is None: - return None - else: - return float(value) - - -def to_str(value: Optional[Any]) -> Optional[str]: - if value is None: - return None - else: - return str(value) - - -def int_to_boolean(value: Optional[int]) -> Optional[bool]: - if value is None: - return None - else: - return bool(value) - - -def str_to_datetime(value: Optional[str]) -> Optional[datetime.datetime]: - if value is not None: - dt_value = datetime_cls.fromisoformat(value) - else: - dt_value = None - return dt_value - - -def str_to_time(value: Optional[str]) -> Optional[datetime.time]: - if value is not None: - dt_value = time_cls.fromisoformat(value) - else: - dt_value = None - return dt_value - - -def str_to_date(value: Optional[str]) -> Optional[datetime.date]: - if value is not None: - dt_value = date_cls.fromisoformat(value) - else: - dt_value = None - return dt_value diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py deleted file mode 100644 index 94ba85f2c2..0000000000 --- a/lib/sqlalchemy/engine/_py_row.py +++ /dev/null @@ -1,129 +0,0 @@ -# engine/_py_row.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -from __future__ import annotations - -import operator -import typing -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterator -from typing import List -from typing import Mapping -from typing import Optional -from typing import Tuple -from typing import Type - -from ..util.typing import TupleAny - -if typing.TYPE_CHECKING: - from .result import _KeyType - from .result import _ProcessorsType - from .result import _TupleGetterType - from .result import ResultMetaData - -MD_INDEX = 0 # integer index in cursor.description - - -class BaseRow: - __slots__ = ("_parent", "_data", "_key_to_index") - - _parent: ResultMetaData - _key_to_index: Mapping[_KeyType, int] - _data: TupleAny - - def __init__( - self, - parent: ResultMetaData, - processors: Optional[_ProcessorsType], - key_to_index: Mapping[_KeyType, int], - data: TupleAny, - ): - """Row objects are constructed by CursorResult objects.""" - object.__setattr__(self, "_parent", parent) - - object.__setattr__(self, "_key_to_index", key_to_index) - - if processors: - object.__setattr__( - self, - "_data", - tuple( - [ - proc(value) if proc else value - for proc, value in zip(processors, data) - ] - ), - ) - else: - object.__setattr__(self, "_data", tuple(data)) - - def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]: - return ( - rowproxy_reconstructor, - (self.__class__, self.__getstate__()), - ) - - def __getstate__(self) -> 
Dict[str, Any]: - return {"_parent": self._parent, "_data": self._data} - - def __setstate__(self, state: Dict[str, Any]) -> None: - parent = state["_parent"] - object.__setattr__(self, "_parent", parent) - object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_key_to_index", parent._key_to_index) - - def _values_impl(self) -> List[Any]: - return list(self) - - def __iter__(self) -> Iterator[Any]: - return iter(self._data) - - def __len__(self) -> int: - return len(self._data) - - def __hash__(self) -> int: - return hash(self._data) - - def __getitem__(self, key: Any) -> Any: - return self._data[key] - - def _get_by_key_impl_mapping(self, key: str) -> Any: - try: - return self._data[self._key_to_index[key]] - except KeyError: - pass - self._parent._key_not_found(key, False) - - def __getattr__(self, name: str) -> Any: - try: - return self._data[self._key_to_index[name]] - except KeyError: - pass - self._parent._key_not_found(name, True) - - def _to_tuple_instance(self) -> Tuple[Any, ...]: - return self._data - - -# This reconstructor is necessary so that pickles with the Cy extension or -# without use the same Binary format. -def rowproxy_reconstructor( - cls: Type[BaseRow], state: Dict[str, Any] -) -> BaseRow: - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj - - -def tuplegetter(*indexes: int) -> _TupleGetterType: - if len(indexes) != 1: - for i in range(1, len(indexes)): - if indexes[i - 1] != indexes[i] - 1: - return operator.itemgetter(*indexes) - # slice form is faster but returns a list if input is list - return operator.itemgetter(slice(indexes[0], indexes[-1] + 1)) diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py deleted file mode 100644 index 2be4322abb..0000000000 --- a/lib/sqlalchemy/engine/_py_util.py +++ /dev/null @@ -1,74 +0,0 @@ -# engine/_py_util.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -from __future__ import annotations - -import typing -from typing import Any -from typing import Mapping -from typing import Optional -from typing import Tuple - -from .. import exc - -if typing.TYPE_CHECKING: - from .interfaces import _CoreAnyExecuteParams - from .interfaces import _CoreMultiExecuteParams - from .interfaces import _DBAPIAnyExecuteParams - from .interfaces import _DBAPIMultiExecuteParams - - -_no_tuple: Tuple[Any, ...] 
= () - - -def _distill_params_20( - params: Optional[_CoreAnyExecuteParams], -) -> _CoreMultiExecuteParams: - if params is None: - return _no_tuple - # Assume list is more likely than tuple - elif isinstance(params, list) or isinstance(params, tuple): - # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance(params[0], (tuple, Mapping)): - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) - - return params - elif isinstance(params, dict) or isinstance( - # only do immutabledict or abc.__instancecheck__ for Mapping after - # we've checked for plain dictionaries and would otherwise raise - params, - Mapping, - ): - return [params] - else: - raise exc.ArgumentError("mapping or list expected for parameters") - - -def _distill_raw_params( - params: Optional[_DBAPIAnyExecuteParams], -) -> _DBAPIMultiExecuteParams: - if params is None: - return _no_tuple - elif isinstance(params, list): - # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance(params[0], (tuple, Mapping)): - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) - - return params - elif isinstance(params, (tuple, dict)) or isinstance( - # only do abc.__instancecheck__ for Mapping after we've checked - # for plain dictionaries and would otherwise raise - params, - Mapping, - ): - # cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params]) - return [params] # type: ignore - else: - raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/lib/sqlalchemy/engine/_row_cy.py b/lib/sqlalchemy/engine/_row_cy.py new file mode 100644 index 0000000000..903bc5b93e --- /dev/null +++ b/lib/sqlalchemy/engine/_row_cy.py @@ -0,0 +1,162 @@ +# engine/_row_cy.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: disable-error-code="misc" +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .result import _KeyType + from .result import _ProcessorsType + from .result import ResultMetaData + +# START GENERATED CYTHON IMPORT +# This section is automatically generated by the script tools/cython_imports.py +try: + # NOTE: the cython compiler needs this "import cython" in the file, it + # can't be only "from sqlalchemy.util import cython" with the fallback + # in that module + import cython +except ModuleNotFoundError: + from sqlalchemy.util import cython + + +def _is_compiled() -> bool: + """Utility function to indicate if this module is compiled or not.""" + return cython.compiled # type: ignore[no-any-return] + + +# END GENERATED CYTHON IMPORT + + +@cython.cclass +class BaseRow: + __slots__ = ("_parent", "_data", "_key_to_index") + + if cython.compiled: + _parent: ResultMetaData = cython.declare(object, visibility="readonly") + _key_to_index: Dict[_KeyType, int] = cython.declare( + dict, visibility="readonly" + ) + _data: Tuple[Any, ...] 
= cython.declare(tuple, visibility="readonly") + + def __init__( + self, + parent: ResultMetaData, + processors: Optional[_ProcessorsType], + key_to_index: Dict[_KeyType, int], + data: Sequence[Any], + ) -> None: + """Row objects are constructed by CursorResult objects.""" + + data_tuple: Tuple[Any, ...] = ( + _apply_processors(processors, data) + if processors is not None + else tuple(data) + ) + self._set_attrs(parent, key_to_index, data_tuple) + + @cython.cfunc + @cython.inline + def _set_attrs( # type: ignore[no-untyped-def] # cython crashes + self, + parent: ResultMetaData, + key_to_index: Dict[_KeyType, int], + data: Tuple[Any, ...], + ): + if cython.compiled: + # cython does not use __setattr__ + self._parent = parent + self._key_to_index = key_to_index + self._data = data + else: + # python does, so use object.__setattr__ + object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_key_to_index", key_to_index) + object.__setattr__(self, "_data", data) + + def __reduce__(self) -> Tuple[Any, Any]: + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self) -> Dict[str, Any]: + return {"_parent": self._parent, "_data": self._data} + + def __setstate__(self, state: Dict[str, Any]) -> None: + parent = state["_parent"] + self._set_attrs(parent, parent._key_to_index, state["_data"]) + + def _values_impl(self) -> List[Any]: + return list(self._data) + + def __iter__(self) -> Iterator[Any]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __hash__(self) -> int: + return hash(self._data) + + def __getitem__(self, key: Any) -> Any: + return self._data[key] + + def _get_by_key_impl_mapping(self, key: _KeyType) -> Any: + return self._get_by_key_impl(key, False) + + @cython.cfunc + def _get_by_key_impl(self, key: _KeyType, attr_err: cython.bint) -> object: + index: Optional[int] = self._key_to_index.get(key) + if index is not None: + return self._data[index] + self._parent._key_not_found(key, attr_err) + + @cython.annotation_typing(False) + def __getattr__(self, name: str) -> Any: + return self._get_by_key_impl(name, True) + + def _to_tuple_instance(self) -> Tuple[Any, ...]: + return self._data + + +@cython.inline +@cython.cfunc +def _apply_processors( + proc: _ProcessorsType, data: Sequence[Any] +) -> Tuple[Any, ...]: + res: List[Any] = list(data) + proc_size: cython.Py_ssize_t = len(proc) + # TODO: would be nice to do this only on the fist row + assert len(res) == proc_size + for i in range(proc_size): + p = proc[i] + if p is not None: + res[i] = p(res[i]) + return tuple(res) + + +# This reconstructor is necessary so that pickles with the Cy extension or +# without use the same Binary format. +# Turn off annotation typing so the compiled version accepts the python +# class too. 
+@cython.annotation_typing(False) +def rowproxy_reconstructor( + cls: Type[BaseRow], state: Dict[str, Any] +) -> BaseRow: + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj diff --git a/lib/sqlalchemy/engine/_util_cy.py b/lib/sqlalchemy/engine/_util_cy.py new file mode 100644 index 0000000000..156fcce998 --- /dev/null +++ b/lib/sqlalchemy/engine/_util_cy.py @@ -0,0 +1,129 @@ +# engine/_util_cy.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: disable-error-code="misc, type-arg" +from __future__ import annotations + +from collections.abc import Mapping +import operator +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING + +from sqlalchemy import exc + +if TYPE_CHECKING: + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .result import _TupleGetterType + +# START GENERATED CYTHON IMPORT +# This section is automatically generated by the script tools/cython_imports.py +try: + # NOTE: the cython compiler needs this "import cython" in the file, it + # can't be only "from sqlalchemy.util import cython" with the fallback + # in that module + import cython +except ModuleNotFoundError: + from sqlalchemy.util import cython + + +def _is_compiled() -> bool: + """Utility function to indicate if this module is compiled or not.""" + return cython.compiled # type: ignore[no-any-return] + + +# END GENERATED CYTHON IMPORT + +_Empty_Tuple: Tuple[Any, ...] 
= cython.declare(tuple, ()) + + +@cython.inline +@cython.cfunc +def _is_mapping_or_tuple(value: object) -> cython.bint: + return ( + isinstance(value, dict) + or isinstance(value, tuple) + or isinstance(value, Mapping) + # only do immutabledict or abc.__instancecheck__ for Mapping after + # we've checked for plain dictionaries and would otherwise raise + ) + + +@cython.inline +@cython.cfunc +@cython.exceptval(0) +def _validate_execute_many_item(params: Sequence[Any]) -> cython.bint: + ret: cython.bint = 1 + if len(params) > 0: + if not _is_mapping_or_tuple(params[0]): + ret = 0 + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + return ret + + +# _is_mapping_or_tuple and _validate_execute_many_item could be +# inlined if pure python perf is a problem +def _distill_params_20( + params: Optional[_CoreAnyExecuteParams], +) -> _CoreMultiExecuteParams: + if params is None: + return _Empty_Tuple + # Assume list is more likely than tuple + elif isinstance(params, list) or isinstance(params, tuple): + # collections_abc.MutableSequence # avoid abc.__instancecheck__ + _validate_execute_many_item(params) + return params + elif isinstance(params, dict) or isinstance(params, Mapping): + # only do immutabledict or abc.__instancecheck__ for Mapping after + # we've checked for plain dictionaries and would otherwise raise + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def _distill_raw_params( + params: Optional[_DBAPIAnyExecuteParams], +) -> _DBAPIMultiExecuteParams: + if params is None: + return _Empty_Tuple + elif isinstance(params, list): + # collections_abc.MutableSequence # avoid abc.__instancecheck__ + _validate_execute_many_item(params) + return params + elif _is_mapping_or_tuple(params): + return [params] # type: ignore[return-value] + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") + + +@cython.cfunc +def _is_contiguous(indexes: Tuple[int, ...]) -> cython.bint: + i: cython.Py_ssize_t + prev: cython.Py_ssize_t + curr: cython.Py_ssize_t + for i in range(1, len(indexes)): + prev = indexes[i - 1] + curr = indexes[i] + if prev != curr - 1: + return False + return True + + +def tuplegetter(*indexes: int) -> _TupleGetterType: + max_index: int + if len(indexes) == 1 or _is_contiguous(indexes): + # slice form is faster but returns a list if input is list + max_index = indexes[-1] + return operator.itemgetter(slice(indexes[0], max_index + 1)) + else: + return operator.itemgetter(*indexes) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index a674c5902b..4f0d104870 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1449,9 +1449,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) -> Any: """Execute a schema.ColumnDefault object.""" - execution_options = self._execution_options.merge_with( - execution_options - ) + exec_opts = self._execution_options.merge_with(execution_options) event_multiparams: Optional[_CoreMultiExecuteParams] event_params: Optional[_CoreAnyExecuteParams] @@ -1467,7 +1465,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): event_multiparams, event_params, ) = self._invoke_before_exec_event( - default, distilled_parameters, execution_options + default, distilled_parameters, exec_opts ) else: event_multiparams = event_params = None @@ -1479,7 +1477,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect = 
self.dialect ctx = dialect.execution_ctx_cls._init_default( - dialect, self, conn, execution_options + dialect, self, conn, exec_opts ) except (exc.PendingRollbackError, exc.ResourceClosedError): raise @@ -1494,7 +1492,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): default, event_multiparams, event_params, - execution_options, + exec_opts, ret, ) @@ -1603,7 +1601,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) -> CursorResult[Unpack[TupleAny]]: """Execute a sql.ClauseElement object.""" - execution_options = elem._execution_options.merge_with( + exec_opts = elem._execution_options.merge_with( self._execution_options, execution_options ) @@ -1615,7 +1613,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): event_multiparams, event_params, ) = self._invoke_before_exec_event( - elem, distilled_parameters, execution_options + elem, distilled_parameters, exec_opts ) if distilled_parameters: @@ -1629,11 +1627,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect = self.dialect - schema_translate_map = execution_options.get( - "schema_translate_map", None - ) + schema_translate_map = exec_opts.get("schema_translate_map", None) - compiled_cache: Optional[CompiledCacheType] = execution_options.get( + compiled_cache: Optional[CompiledCacheType] = exec_opts.get( "compiled_cache", self.engine._compiled_cache ) @@ -1650,7 +1646,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect.execution_ctx_cls._init_compiled, compiled_sql, distilled_parameters, - execution_options, + exec_opts, compiled_sql, distilled_parameters, elem, @@ -1663,7 +1659,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): elem, event_multiparams, event_params, - execution_options, + exec_opts, ret, ) return ret @@ -1680,7 +1676,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ - execution_options = compiled.execution_options.merge_with( + exec_opts = compiled.execution_options.merge_with( self._execution_options, execution_options ) @@ -1691,7 +1687,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): event_multiparams, event_params, ) = self._invoke_before_exec_event( - compiled, distilled_parameters, execution_options + compiled, distilled_parameters, exec_opts ) dialect = self.dialect @@ -1701,7 +1697,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect.execution_ctx_cls._init_compiled, compiled, distilled_parameters, - execution_options, + exec_opts, compiled, distilled_parameters, None, @@ -1713,7 +1709,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): compiled, event_multiparams, event_params, - execution_options, + exec_opts, ret, ) return ret @@ -1779,9 +1775,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): distilled_parameters = _distill_raw_params(parameters) - execution_options = self._execution_options.merge_with( - execution_options - ) + exec_opts = self._execution_options.merge_with(execution_options) dialect = self.dialect ret = self._execute_context( @@ -1789,7 +1783,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect.execution_ctx_cls._init_statement, statement, None, - execution_options, + exec_opts, statement, distilled_parameters, ) diff --git a/lib/sqlalchemy/engine/processors.py 
b/lib/sqlalchemy/engine/processors.py index 610e03d5a1..47f07e006c 100644 --- a/lib/sqlalchemy/engine/processors.py +++ b/lib/sqlalchemy/engine/processors.py @@ -14,48 +14,69 @@ They all share one common characteristic: None is passed through unchanged. """ from __future__ import annotations -import typing +import datetime +from typing import Callable +from typing import Optional +from typing import Pattern +from typing import TypeVar +from typing import Union -from ._py_processors import str_to_datetime_processor_factory # noqa -from ..util._has_cy import HAS_CYEXTENSION +from ._processors_cy import int_to_boolean as int_to_boolean # noqa: F401 +from ._processors_cy import str_to_date as str_to_date # noqa: F401 +from ._processors_cy import str_to_datetime as str_to_datetime # noqa: F401 +from ._processors_cy import str_to_time as str_to_time # noqa: F401 +from ._processors_cy import to_float as to_float # noqa: F401 +from ._processors_cy import to_str as to_str # noqa: F401 -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_processors import int_to_boolean as int_to_boolean - from ._py_processors import str_to_date as str_to_date - from ._py_processors import str_to_datetime as str_to_datetime - from ._py_processors import str_to_time as str_to_time - from ._py_processors import ( +if True: + from ._processors_cy import ( # noqa: F401 to_decimal_processor_factory as to_decimal_processor_factory, ) - from ._py_processors import to_float as to_float - from ._py_processors import to_str as to_str -else: - from sqlalchemy.cyextension.processors import ( - DecimalResultProcessor, - ) - from sqlalchemy.cyextension.processors import ( # noqa: F401 - int_to_boolean as int_to_boolean, - ) - from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 - str_to_date as str_to_date, - ) - from sqlalchemy.cyextension.processors import ( # noqa: F401 - str_to_datetime as str_to_datetime, - ) - from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 - str_to_time as str_to_time, - ) - from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 - to_float as to_float, - ) - from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 - to_str as to_str, - ) - def to_decimal_processor_factory(target_class, scale): - # Note that the scale argument is not taken into account for integer - # values in the C implementation while it is in the Python one. - # For example, the Python implementation might return - # Decimal('5.00000') whereas the C implementation will - # return Decimal('5'). These are equivalent of course. - return DecimalResultProcessor(target_class, "%%.%df" % scale).process + +_DT = TypeVar( + "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date] +) + + +def str_to_datetime_processor_factory( + regexp: Pattern[str], type_: Callable[..., _DT] +) -> Callable[[Optional[str]], Optional[_DT]]: + rmatch = regexp.match + # Even on python2.6 datetime.strptime is both slower than this code + # and it does not support microseconds. + has_named_groups = bool(regexp.groupindex) + + def process(value: Optional[str]) -> Optional[_DT]: + if value is None: + return None + else: + try: + m = rmatch(value) + except TypeError as err: + raise ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." 
% (type_.__name__, value) + ) from err + + if m is None: + raise ValueError( + "Couldn't parse %s string: " + "'%s'" % (type_.__name__, value) + ) + if has_named_groups: + groups = m.groupdict(0) + return type_( + **dict( + list( + zip( + iter(groups.keys()), + list(map(int, iter(groups.values()))), + ) + ) + ) + ) + else: + return type_(*list(map(int, m.groups(0)))) + + return process diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index fad6102551..226b7f8c63 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -33,6 +33,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from ._util_cy import tuplegetter as tuplegetter from .row import Row from .row import RowMapping from .. import exc @@ -43,18 +44,12 @@ from ..sql.base import InPlaceGenerative from ..util import deprecated from ..util import HasMemoized_ro_memoized_attribute from ..util import NONE_SET -from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import Literal from ..util.typing import Self from ..util.typing import TupleAny from ..util.typing import TypeVarTuple from ..util.typing import Unpack -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_row import tuplegetter as tuplegetter -else: - from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter - if typing.TYPE_CHECKING: from ..sql.schema import Column from ..sql.type_api import _ResultProcessorType @@ -103,7 +98,7 @@ class ResultMetaData: _keymap: _KeyMapType _keys: Sequence[str] _processors: Optional[_ProcessorsType] - _key_to_index: Mapping[_KeyType, int] + _key_to_index: Dict[_KeyType, int] @property def keys(self) -> RMKeyView: @@ -183,7 +178,7 @@ class ResultMetaData: def _make_key_to_index( self, keymap: Mapping[_KeyType, Sequence[Any]], index: int - ) -> Mapping[_KeyType, int]: + ) -> Dict[_KeyType, int]: return { key: rec[index] for key, rec in keymap.items() @@ -462,7 +457,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): def process_row( metadata: ResultMetaData, processors: Optional[_ProcessorsType], - key_to_index: Mapping[_KeyType, int], + key_to_index: Dict[_KeyType, int], scalar_obj: Any, ) -> Row[Unpack[TupleAny]]: return _proc( diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 79d8026c62..893b9c5c0c 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -25,19 +25,13 @@ from typing import Optional from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING -from typing import TypeVar +from ._row_cy import BaseRow as BaseRow from ..sql import util as sql_util from ..util import deprecated -from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import TypeVarTuple from ..util.typing import Unpack -if TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_row import BaseRow as BaseRow -else: - from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow - if TYPE_CHECKING: from typing import Tuple as _RowBase @@ -48,7 +42,6 @@ else: _RowBase = Sequence -_T = TypeVar("_T", bound=Any) _Ts = TypeVarTuple("_Ts") diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 34c615c841..284973b455 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -7,29 +7,18 @@ from __future__ import annotations -import typing from typing import Any from typing import Callable from typing import Optional from typing import Protocol from typing import TypeVar +from ._util_cy import _distill_params_20 
as _distill_params_20 # noqa: F401 +from ._util_cy import _distill_raw_params as _distill_raw_params # noqa: F401 from .. import exc from .. import util -from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import Self -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_util import _distill_params_20 as _distill_params_20 - from ._py_util import _distill_raw_params as _distill_raw_params -else: - from sqlalchemy.cyextension.util import ( # noqa: F401 - _distill_params_20 as _distill_params_20, - ) - from sqlalchemy.cyextension.util import ( # noqa: F401 - _distill_raw_params as _distill_raw_params, - ) - _C = TypeVar("_C", bound=Callable[[], Any]) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index eeef7241c8..d112680df6 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -1553,14 +1553,15 @@ class InstrumentedDict(Dict[_KT, _VT]): """An instrumented version of the built-in dict.""" -__canned_instrumentation: util.immutabledict[Any, _CollectionFactoryType] = ( +__canned_instrumentation = cast( + util.immutabledict[Any, _CollectionFactoryType], util.immutabledict( { list: InstrumentedList, set: InstrumentedSet, dict: InstrumentedDict, } - ) + ), ) __interfaces: util.immutabledict[ diff --git a/lib/sqlalchemy/sql/_py_util.py b/lib/sqlalchemy/sql/_py_util.py deleted file mode 100644 index df372bf5d5..0000000000 --- a/lib/sqlalchemy/sql/_py_util.py +++ /dev/null @@ -1,75 +0,0 @@ -# sql/_py_util.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php - -from __future__ import annotations - -import typing -from typing import Any -from typing import Dict -from typing import Tuple -from typing import Union - -from ..util.typing import Literal - -if typing.TYPE_CHECKING: - from .cache_key import CacheConst - - -class prefix_anon_map(Dict[str, str]): - """A map that creates new keys for missing key access. - - Considers keys of the form " " to produce - new symbols "_", where "index" is an incrementing integer - corresponding to . - - Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which - is otherwise usually used for this type of operation. - - """ - - def __missing__(self, key: str) -> str: - (ident, derived) = key.split(" ", 1) - anonymous_counter = self.get(derived, 1) - self[derived] = anonymous_counter + 1 # type: ignore - value = f"{derived}_{anonymous_counter}" - self[key] = value - return value - - -class cache_anon_map( - Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]] -): - """A map that creates new keys for missing key access. - - Produces an incrementing sequence given a series of unique keys. - - This is similar to the compiler prefix_anon_map class although simpler. - - Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which - is otherwise usually used for this type of operation. 
- - """ - - _index = 0 - - def get_anon(self, object_: Any) -> Tuple[str, bool]: - idself = id(object_) - if idself in self: - s_val = self[idself] - assert s_val is not True - return s_val, True - else: - # inline of __missing__ - self[idself] = id_ = str(self._index) - self._index += 1 - - return id_, False - - def __missing__(self, key: int) -> str: - self[key] = val = str(self._index) - self._index += 1 - return val diff --git a/lib/sqlalchemy/sql/_util_cy.py b/lib/sqlalchemy/sql/_util_cy.py new file mode 100644 index 0000000000..2d15b1c7e2 --- /dev/null +++ b/lib/sqlalchemy/sql/_util_cy.py @@ -0,0 +1,108 @@ +# sql/_util_cy.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Dict +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from ..util.typing import Literal + +if TYPE_CHECKING: + from .cache_key import CacheConst + +# START GENERATED CYTHON IMPORT +# This section is automatically generated by the script tools/cython_imports.py +try: + # NOTE: the cython compiler needs this "import cython" in the file, it + # can't be only "from sqlalchemy.util import cython" with the fallback + # in that module + import cython +except ModuleNotFoundError: + from sqlalchemy.util import cython + + +def _is_compiled() -> bool: + """Utility function to indicate if this module is compiled or not.""" + return cython.compiled # type: ignore[no-any-return] + + +# END GENERATED CYTHON IMPORT + + +@cython.cclass +class prefix_anon_map(Dict[str, str]): + """A map that creates new keys for missing key access. + + Considers keys of the form " " to produce + new symbols "_", where "index" is an incrementing integer + corresponding to . + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key: str, /) -> str: + derived: str + value: str + self_dict: dict = self # type: ignore[type-arg] + + derived = key.split(" ", 1)[1] + + anonymous_counter: int = self_dict.get(derived, 1) + self_dict[derived] = anonymous_counter + 1 + value = f"{derived}_{anonymous_counter}" + self_dict[key] = value + return value + + +@cython.cclass +class anon_map( + Dict[ + Union[int, str, "Literal[CacheConst.NO_CACHE]"], + Union[Literal[True], str], + ] +): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. 
+ + """ + + if cython.compiled: + _index: cython.uint + + def __cinit__(self): # type: ignore[no-untyped-def] + self._index = 0 + + else: + _index: int = 0 # type: ignore[no-redef] + + def get_anon(self, obj: object, /) -> Tuple[str, bool]: + self_dict: dict = self # type: ignore[type-arg] + + idself = id(obj) + if idself in self_dict: + return self_dict[idself], True + else: + return self.__missing__(idself), False + + def __missing__(self, key: Union[int, str], /) -> str: + val: str + self_dict: dict = self # type: ignore[type-arg] + + self_dict[key] = val = str(self._index) + self._index += 1 + return val diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 05025909a4..3e7c24eaff 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -35,10 +35,11 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from ._util_cy import anon_map as anon_map +from ._util_cy import prefix_anon_map as prefix_anon_map # noqa: F401 from .. import exc from .. import util from ..util import langhelpers -from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import Literal from ..util.typing import Self @@ -46,17 +47,6 @@ if TYPE_CHECKING: from .annotation import _AnnotationDict from .elements import ColumnElement -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_util import prefix_anon_map as prefix_anon_map - from ._py_util import cache_anon_map as anon_map -else: - from sqlalchemy.cyextension.util import ( # noqa: F401,E501 - prefix_anon_map as prefix_anon_map, - ) - from sqlalchemy.cyextension.util import ( # noqa: F401,E501 - cache_anon_map as anon_map, - ) - __all__ = [ "iterate", diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 1a4d4bb30a..6024b39add 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -136,7 +136,7 @@ def _log_sqlalchemy_info(session): import sqlalchemy from sqlalchemy import __version__ from sqlalchemy.util import has_compiled_ext - from sqlalchemy.util._has_cy import _CYEXTENSION_MSG + from sqlalchemy.util._has_cython import _CYEXTENSION_MSG greet = "sqlalchemy installation" site = "no user site" if sys.flags.no_user_site else "user site loaded" @@ -146,9 +146,9 @@ def _log_sqlalchemy_info(session): ] if has_compiled_ext(): - from sqlalchemy.cyextension import util + from sqlalchemy.engine import _util_cy - msgs.append(f"compiled extension enabled, e.g. {util.__file__} ") + msgs.append(f"compiled extension enabled, e.g. 
{_util_cy.__file__} ") else: msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}") diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 5dd0179505..3d092a0223 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -35,33 +35,15 @@ from typing import Union from typing import ValuesView import weakref -from ._has_cy import HAS_CYEXTENSION +from ._collections_cy import IdentitySet as IdentitySet +from ._collections_cy import OrderedSet as OrderedSet +from ._collections_cy import unique_list as unique_list # noqa: F401 +from ._immutabledict_cy import immutabledict as immutabledict +from ._immutabledict_cy import ImmutableDictBase as ImmutableDictBase +from ._immutabledict_cy import ReadOnlyContainer as ReadOnlyContainer from .typing import is_non_string_iterable from .typing import Literal -if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_collections import immutabledict as immutabledict - from ._py_collections import IdentitySet as IdentitySet - from ._py_collections import ReadOnlyContainer as ReadOnlyContainer - from ._py_collections import ImmutableDictBase as ImmutableDictBase - from ._py_collections import OrderedSet as OrderedSet - from ._py_collections import unique_list as unique_list -else: - from sqlalchemy.cyextension.immutabledict import ( - ReadOnlyContainer as ReadOnlyContainer, - ) - from sqlalchemy.cyextension.immutabledict import ( - ImmutableDictBase as ImmutableDictBase, - ) - from sqlalchemy.cyextension.immutabledict import ( - immutabledict as immutabledict, - ) - from sqlalchemy.cyextension.collections import IdentitySet as IdentitySet - from sqlalchemy.cyextension.collections import OrderedSet as OrderedSet - from sqlalchemy.cyextension.collections import ( # noqa - unique_list as unique_list, - ) - _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT", bound=Any) @@ -144,7 +126,7 @@ class FacadeDict(ImmutableDictBase[_KT, _VT]): """A dictionary that is not publicly mutable.""" def __new__(cls, *args: Any) -> FacadeDict[Any, Any]: - new = ImmutableDictBase.__new__(cls) + new: FacadeDict[Any, Any] = ImmutableDictBase.__new__(cls) return new def copy(self) -> NoReturn: @@ -320,13 +302,7 @@ class WeakSequence(Sequence[_T]): return obj() -class OrderedIdentitySet(IdentitySet): - def __init__(self, iterable: Optional[Iterable[Any]] = None): - IdentitySet.__init__(self) - self._members = OrderedDict() - if iterable: - for o in iterable: - self.add(o) +OrderedIdentitySet = IdentitySet class PopulateDict(Dict[_KT, _VT]): diff --git a/lib/sqlalchemy/util/_collections_cy.py b/lib/sqlalchemy/util/_collections_cy.py new file mode 100644 index 0000000000..0931ac450c --- /dev/null +++ b/lib/sqlalchemy/util/_collections_cy.py @@ -0,0 +1,528 @@ +# util/_collections_cy.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: disable-error-code="misc, no-any-return, no-untyped-def, override" + +from __future__ import annotations + +from typing import AbstractSet +from typing import Any +from typing import Dict +from typing import Hashable +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .typing import 
Self
+
+# START GENERATED CYTHON IMPORT
+# This section is automatically generated by the script tools/cython_imports.py
+try:
+    # NOTE: the cython compiler needs this "import cython" in the file, it
+    # can't be only "from sqlalchemy.util import cython" with the fallback
+    # in that module
+    import cython
+except ModuleNotFoundError:
+    from sqlalchemy.util import cython
+
+
+def _is_compiled() -> bool:
+    """Utility function to indicate if this module is compiled or not."""
+    return cython.compiled  # type: ignore[no-any-return]
+
+
+# END GENERATED CYTHON IMPORT
+
+if cython.compiled:
+    from cython.cimports.cpython.long import PyLong_FromUnsignedLongLong
+elif TYPE_CHECKING:
+
+    def PyLong_FromUnsignedLongLong(v: Any) -> int: ...
+
+
+_T = TypeVar("_T")
+_S = TypeVar("_S")
+
+
+@cython.ccall
+def unique_list(seq: Iterable[_T]) -> List[_T]:
+    # this version seems somewhat faster for smaller sizes, but it's
+    # significantly slower on larger sizes
+    # w = {x:None for x in seq}
+    # return PyDict_Keys(w) if cython.compiled else list(w)
+    if cython.compiled:
+        seen: Set[_T] = set()
+        return [x for x in seq if x not in seen and not set.add(seen, x)]
+    else:
+        return list(dict.fromkeys(seq))
+
+    # In case passing a hashfunc is required in the future, two versions were
+    # tested:
+    # - this version is faster but returns the *last* element matching the
+    #   hash.
+    # from cython.cimports.cpython.dict import PyDict_Values
+    # w: dict = {hashfunc(x): x for x in seq}
+    # return PyDict_Values(w) if cython.compiled else list(w.values())
+    # - this version is slower but returns the *first* element matching the
+    #   hash.
+    # seen: set = set()
+    # res: list = []
+    # for x in seq:
+    #     h = hashfunc(x)
+    #     if h not in seen:
+    #         res.append(x)
+    #         seen.add(h)
+    # return res
+
+
+@cython.cclass
+class OrderedSet(Set[_T]):
+    """A set implementation that maintains insertion order."""
+
+    __slots__ = ("_list",)
+    _list: List[_T]
+
+    @classmethod
+    def __class_getitem__(cls, key: Any) -> type[Self]:
+        return cls
+
+    def __init__(self, d: Optional[Iterable[_T]] = None) -> None:
+        if d is not None:
+            if isinstance(d, set) or isinstance(d, dict):
+                self._list = list(d)
+            else:
+                self._list = unique_list(d)
+            set.__init__(self, self._list)
+        else:
+            self._list = []
+            set.__init__(self)
+
+    def copy(self) -> OrderedSet[_T]:
+        return self._from_list(list(self._list))
+
+    @cython.final
+    @cython.cfunc
+    @cython.inline
+    def _from_list(self, new_list: List[_T]) -> OrderedSet:  # type: ignore[type-arg] # noqa: E501
+        new: OrderedSet = OrderedSet.__new__(OrderedSet)  # type: ignore[type-arg] # noqa: E501
+        new._list = new_list
+        set.update(new, new_list)
+        return new
+
+    def add(self, element: _T, /) -> None:
+        if element not in self:
+            self._list.append(element)
+            set.add(self, element)
+
+    def remove(self, element: _T, /) -> None:
+        # set.remove will raise if element is not in self
+        set.remove(self, element)
+        self._list.remove(element)
+
+    def pop(self) -> _T:
+        try:
+            value = self._list.pop()
+        except IndexError:
+            raise KeyError("pop from an empty set") from None
+        set.remove(self, value)
+        return value
+
+    def insert(self, pos: cython.Py_ssize_t, element: _T, /) -> None:
+        if element not in self:
+            self._list.insert(pos, element)
+            set.add(self, element)
+
+    def discard(self, element: _T, /) -> None:
+        if element in self:
+            set.remove(self, element)
+            self._list.remove(element)
+
+    def clear(self) -> None:
+        set.clear(self)  # type: ignore[arg-type]
+        self._list = []
+
+    def __getitem__(self, key: cython.Py_ssize_t) -> _T:
+ return self._list[key] + + def __iter__(self) -> Iterator[_T]: + return iter(self._list) + + def __add__(self, other: Iterator[_T]) -> OrderedSet[_T]: + return self.union(other) + + def __repr__(self) -> str: + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + # @cython.ccall # cdef function cannot have star argument + def update(self, *iterables: Iterable[_T]) -> None: + for iterable in iterables: + for element in iterable: + # inline of add. mainly for python, since for cython we + # could create an @cfunc @inline _add function that would + # perform the same + if element not in self: + self._list.append(element) + set.add(self, element) + + def __ior__( + self: OrderedSet[Union[_T, _S]], iterable: AbstractSet[_S] + ) -> OrderedSet[Union[_T, _S]]: + self.update(iterable) + return self + + # @cython.ccall # cdef function cannot have star argument + def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]: + result: OrderedSet[Union[_T, _S]] = self._from_list(list(self._list)) + result.update(*other) + return result + + def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: + return self.union(other) + + # @cython.ccall # cdef function cannot have star argument + def intersection(self, *other: Iterable[Hashable]) -> OrderedSet[_T]: + other_set: Set[Any] = set.intersection(self, *other) + return self._from_list([a for a in self._list if a in other_set]) + + def __and__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]: + return self.intersection(other) + + @cython.ccall + @cython.annotation_typing(False) # avoid cython crash from generic return + def symmetric_difference( + self, other: Iterable[_S], / + ) -> OrderedSet[Union[_T, _S]]: + collection: Iterable[Any] + other_set: Set[_S] + if isinstance(other, set): + other_set = cython.cast(set, other) + collection = other_set + elif hasattr(other, "__len__"): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) + result: OrderedSet[Union[_T, _S]] = self._from_list( + [a for a in self._list if a not in other_set] + ) + result.update([a for a in collection if a not in self]) + return result + + def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: + return self.symmetric_difference(other) + + # @cython.ccall # cdef function cannot have star argument + def difference(self, *other: Iterable[Hashable]) -> OrderedSet[_T]: + other_set: Set[Any] = set.difference(self, *other) + return self._from_list([a for a in self._list if a in other_set]) + + def __sub__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]: + return self.difference(other) + + # @cython.ccall # cdef function cannot have star argument + def intersection_update(self, *other: Iterable[Hashable]) -> None: + set.intersection_update(self, *other) + self._list = [a for a in self._list if a in self] + + def __iand__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]: + self.intersection_update(other) + return self + + @cython.ccall + @cython.annotation_typing(False) # avoid cython crash from generic return + def symmetric_difference_update(self, other: Iterable[_T], /) -> None: + collection = other if hasattr(other, "__len__") else list(other) + set.symmetric_difference_update(self, collection) + self._list = [a for a in self._list if a in self] + self._list += [a for a in collection if a in self] + + def __ixor__( + self: OrderedSet[Union[_T, _S]], other: AbstractSet[_S] + ) -> OrderedSet[Union[_T, _S]]: + self.symmetric_difference_update(other) 
+ return self + + # @cython.ccall # cdef function cannot have star argument + def difference_update(self, *other: Iterable[Hashable]) -> None: + set.difference_update(self, *other) + self._list = [a for a in self._list if a in self] + + def __isub__(self, other: AbstractSet[Hashable]) -> OrderedSet[_T]: + self.difference_update(other) + return self + + +if cython.compiled: + + @cython.final + @cython.inline + @cython.cfunc + @cython.annotation_typing(False) + def _get_id(item: Any) -> int: + return PyLong_FromUnsignedLongLong( + cython.cast( + cython.ulonglong, + cython.cast(cython.pointer(cython.void), item), + ) + ) + +else: + _get_id = id + + +@cython.cclass +class IdentitySet: + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + __slots__ = ("_members",) + _members: Dict[int, Any] + + def __init__(self, iterable: Optional[Iterable[Any]] = None): + # the code assumes this class is ordered + self._members = {} + if iterable: + self.update(iterable) + + def add(self, value: Any, /) -> None: + self._members[_get_id(value)] = value + + def __contains__(self, value) -> bool: + return _get_id(value) in self._members + + @cython.ccall + def remove(self, value: Any, /): + del self._members[_get_id(value)] + + def discard(self, value, /) -> None: + try: + self.remove(value) + except KeyError: + pass + + def pop(self) -> Any: + pair: Tuple[Any, Any] + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self) -> None: + self._members.clear() + + def __eq__(self, other: Any) -> bool: + other_: IdentitySet + if isinstance(other, IdentitySet): + other_ = other + return self._members == other_._members + else: + return False + + def __ne__(self, other: Any) -> bool: + other_: IdentitySet + if isinstance(other, IdentitySet): + other_ = other + return self._members != other_._members + else: + return True + + @cython.ccall + def issubset(self, iterable: Iterable[Any], /) -> cython.bint: + other: IdentitySet + if isinstance(iterable, IdentitySet): + other = iterable + else: + other = self.__class__(iterable) + + return self._members.keys() <= other._members.keys() + + def __le__(self, other: Any) -> bool: + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + @cython.ccall + def issuperset(self, iterable: Iterable[Any], /) -> cython.bint: + other: IdentitySet + if isinstance(iterable, IdentitySet): + other = iterable + else: + other = self.__class__(iterable) + + return self._members.keys() >= other._members.keys() + + def __ge__(self, other: Any) -> bool: + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other: Any) -> bool: + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + @cython.ccall + def union(self, iterable: Iterable[Any], /) -> IdentitySet: + result: IdentitySet = self.__class__() + result._members.update(self._members) + result.update(iterable) + return result + + def __or__(self, other: Any) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + 
@cython.ccall + def update(self, iterable: Iterable[Any], /): + members: Dict[int, Any] = self._members + if isinstance(iterable, IdentitySet): + members.update(cython.cast(IdentitySet, iterable)._members) + else: + for obj in iterable: + members[_get_id(obj)] = obj + + def __ior__(self, other: Any) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + @cython.ccall + def difference(self, iterable: Iterable[Any], /) -> IdentitySet: + result: IdentitySet = self.__new__(self.__class__) + if isinstance(iterable, IdentitySet): + other = cython.cast(IdentitySet, iterable)._members.keys() + else: + other = {_get_id(obj) for obj in iterable} + + result._members = { + k: v for k, v in self._members.items() if k not in other + } + return result + + def __sub__(self, other: IdentitySet) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + # def difference_update(self, iterable: Iterable[Any]) -> None: + @cython.ccall + def difference_update(self, iterable: Iterable[Any], /): + other: IdentitySet = self.difference(iterable) + self._members = other._members + + def __isub__(self, other: IdentitySet) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + @cython.ccall + def intersection(self, iterable: Iterable[Any], /) -> IdentitySet: + result: IdentitySet = self.__new__(self.__class__) + if isinstance(iterable, IdentitySet): + other = cython.cast(IdentitySet, iterable)._members + else: + other = {_get_id(obj) for obj in iterable} + result._members = { + k: v for k, v in self._members.items() if k in other + } + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + # def intersection_update(self, iterable: Iterable[Any]) -> None: + @cython.ccall + def intersection_update(self, iterable: Iterable[Any], /): + other: IdentitySet = self.intersection(iterable) + self._members = other._members + + def __iand__(self, other: IdentitySet) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + @cython.ccall + def symmetric_difference(self, iterable: Iterable[Any], /) -> IdentitySet: + result: IdentitySet = self.__new__(self.__class__) + other: Dict[int, Any] + if isinstance(iterable, IdentitySet): + other = cython.cast(IdentitySet, iterable)._members + else: + other = {_get_id(obj): obj for obj in iterable} + result._members = { + k: v for k, v in self._members.items() if k not in other + } + result._members.update( + [(k, v) for k, v in other.items() if k not in self._members] + ) + return result + + def __xor__(self, other: IdentitySet) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + # def symmetric_difference_update(self, iterable: Iterable[Any]) -> None: + @cython.ccall + def symmetric_difference_update(self, iterable: Iterable[Any], /): + other: IdentitySet = self.symmetric_difference(iterable) + self._members = other._members + + def __ixor__(self, other: IdentitySet) -> IdentitySet: + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + @cython.ccall + def copy(self) -> IdentitySet: + cp: IdentitySet = self.__new__(self.__class__) + cp._members = self._members.copy() + return cp + + def 
__copy__(self) -> IdentitySet: + return self.copy() + + def __len__(self) -> int: + return len(self._members) + + def __iter__(self) -> Iterator[Any]: + return iter(self._members.values()) + + def __hash__(self) -> NoReturn: + raise TypeError("set objects are unhashable") + + def __repr__(self) -> str: + return "%s(%r)" % ( + self.__class__.__name__, + list(self._members.values()), + ) diff --git a/lib/sqlalchemy/util/_has_cy.py b/lib/sqlalchemy/util/_has_cy.py deleted file mode 100644 index 7713e236ac..0000000000 --- a/lib/sqlalchemy/util/_has_cy.py +++ /dev/null @@ -1,40 +0,0 @@ -# util/_has_cy.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - -import os -import typing - - -def _import_cy_extensions(): - # all cython extension extension modules are treated as optional by the - # setup, so to ensure that all are compiled, all should be imported here - from ..cyextension import collections - from ..cyextension import immutabledict - from ..cyextension import processors - from ..cyextension import resultproxy - from ..cyextension import util - - return (collections, immutabledict, processors, resultproxy, util) - - -_CYEXTENSION_MSG: str -if not typing.TYPE_CHECKING: - if os.environ.get("DISABLE_SQLALCHEMY_CEXT_RUNTIME"): - HAS_CYEXTENSION = False - _CYEXTENSION_MSG = "DISABLE_SQLALCHEMY_CEXT_RUNTIME is set" - else: - try: - _import_cy_extensions() - except ImportError as err: - HAS_CYEXTENSION = False - _CYEXTENSION_MSG = str(err) - else: - _CYEXTENSION_MSG = "Loaded" - HAS_CYEXTENSION = True -else: - HAS_CYEXTENSION = False diff --git a/lib/sqlalchemy/util/_has_cython.py b/lib/sqlalchemy/util/_has_cython.py new file mode 100644 index 0000000000..ef99d58143 --- /dev/null +++ b/lib/sqlalchemy/util/_has_cython.py @@ -0,0 +1,44 @@ +# util/_has_cython.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import typing + + +def _all_cython_modules(): + """Returns all modules that can be compiled using cython. + Call ``_is_compiled()`` to check if the module is compiled or not. + """ + from . import _collections_cy + from . 
import _immutabledict_cy + from ..engine import _processors_cy + from ..engine import _row_cy + from ..engine import _util_cy as engine_util + from ..sql import _util_cy as sql_util + + return ( + _collections_cy, + _immutabledict_cy, + _processors_cy, + _row_cy, + engine_util, + sql_util, + ) + + +_CYEXTENSION_MSG: str +if not typing.TYPE_CHECKING: + HAS_CYEXTENSION = all(m._is_compiled() for m in _all_cython_modules()) + if HAS_CYEXTENSION: + _CYEXTENSION_MSG = "Loaded" + else: + _CYEXTENSION_MSG = ", ".join( + m.__name__ for m in _all_cython_modules() if not m._is_compiled() + ) + _CYEXTENSION_MSG = f"Modules {_CYEXTENSION_MSG} are not compiled" +else: + HAS_CYEXTENSION = False diff --git a/lib/sqlalchemy/util/_immutabledict_cy.py b/lib/sqlalchemy/util/_immutabledict_cy.py new file mode 100644 index 0000000000..cf1867de17 --- /dev/null +++ b/lib/sqlalchemy/util/_immutabledict_cy.py @@ -0,0 +1,208 @@ +# util/_immutabledict_cy.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: disable-error-code="misc, arg-type" +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Hashable +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import TypeVar + +from .typing import Self + +# START GENERATED CYTHON IMPORT +# This section is automatically generated by the script tools/cython_imports.py +try: + # NOTE: the cython compiler needs this "import cython" in the file, it + # can't be only "from sqlalchemy.util import cython" with the fallback + # in that module + import cython +except ModuleNotFoundError: + from sqlalchemy.util import cython + + +def _is_compiled() -> bool: + """Utility function to indicate if this module is compiled or not.""" + return cython.compiled # type: ignore[no-any-return] + + +# END GENERATED CYTHON IMPORT + +if cython.compiled: + from cython.cimports.cpython.dict import PyDict_Update +else: + PyDict_Update = dict.update + + +def _immutable_fn(obj: object) -> NoReturn: + raise TypeError(f"{obj.__class__.__name__} object is immutable") + + +class ReadOnlyContainer: + __slots__ = () + + def _readonly(self) -> NoReturn: + raise TypeError( + f"{self.__class__.__name__} object is immutable and/or readonly" + ) + + def __delitem__(self, key: Any) -> NoReturn: + self._readonly() + + def __setitem__(self, key: Any, value: Any) -> NoReturn: + self._readonly() + + def __setattr__(self, key: Any, value: Any) -> NoReturn: + self._readonly() + + +_KT = TypeVar("_KT", bound=Hashable) +_VT = TypeVar("_VT", bound=Any) + + +@cython.cclass +class ImmutableDictBase(Dict[_KT, _VT]): + # NOTE: this method is required in 3.9 and speeds up the use case + # ImmutableDictBase[str,int](a_dict) significantly + @classmethod + def __class_getitem__( # type: ignore[override] + cls, key: Any + ) -> type[Self]: + return cls + + def __delitem__(self, key: Any) -> NoReturn: + _immutable_fn(self) + + def __setitem__(self, key: Any, value: Any) -> NoReturn: + _immutable_fn(self) + + def __setattr__(self, key: Any, value: Any) -> NoReturn: + _immutable_fn(self) + + def clear(self) -> NoReturn: + _immutable_fn(self) + + def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn: + _immutable_fn(self) + + def popitem(self) -> NoReturn: + _immutable_fn(self) + + def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn: + 
_immutable_fn(self) + + def update(self, *arg: Any, **kw: Any) -> NoReturn: + _immutable_fn(self) + + +# NOTE: can't extend from ImmutableDictBase[_KT, _VT] due to a compiler +# crash in doing so. Extending from ImmutableDictBase is ok, but requires +# a type checking section and other workaround for the crash +@cython.cclass +class immutabledict(Dict[_KT, _VT]): + """An immutable version of a dict.""" + + # ImmutableDictBase start + @classmethod + def __class_getitem__( # type: ignore[override] + cls, key: Any + ) -> type[Self]: + return cls + + def __delitem__(self, key: Any) -> NoReturn: + _immutable_fn(self) + + def __setitem__(self, key: Any, value: Any) -> NoReturn: + _immutable_fn(self) + + def __setattr__(self, key: Any, value: Any) -> NoReturn: + _immutable_fn(self) + + def clear(self) -> NoReturn: + _immutable_fn(self) + + def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn: + _immutable_fn(self) + + def popitem(self) -> NoReturn: + _immutable_fn(self) + + def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn: + _immutable_fn(self) + + def update(self, *arg: Any, **kw: Any) -> NoReturn: + _immutable_fn(self) + + # ImmutableDictBase end + + def __repr__(self) -> str: + return f"immutabledict({dict.__repr__(self)})" + + @cython.annotation_typing(False) # avoid cython crash from generic return + def union( + self, other: Optional[Mapping[_KT, _VT]] = None, / + ) -> immutabledict[_KT, _VT]: + if not other: + return self + # new + update is faster than immutabledict(self) + result: immutabledict = immutabledict() # type: ignore[type-arg] + PyDict_Update(result, self) + if isinstance(other, dict): + # c version of PyDict_Update supports only dicts + PyDict_Update(result, other) + else: + dict.update(result, other) + return result + + @cython.annotation_typing(False) # avoid cython crash from generic return + def merge_with( + self, *dicts: Optional[Mapping[_KT, _VT]] + ) -> immutabledict[_KT, _VT]: + result: Optional[immutabledict] = None # type: ignore[type-arg] + d: object + if not dicts: + return self + for d in dicts: + if d is not None and len(d) > 0: + if result is None: + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + if isinstance(d, dict): + # c version of PyDict_Update supports only dicts + PyDict_Update(result, d) + else: + dict.update(result, d) + + return self if result is None else result + + def copy(self) -> Self: + return self + + def __reduce__(self) -> Any: + return immutabledict, (dict(self),) + + # PEP 584 + def __ior__(self, __value: Any, /) -> NoReturn: + _immutable_fn(self) + + def __or__( # type: ignore[override] + self, __value: Mapping[_KT, _VT], / + ) -> immutabledict[_KT, _VT]: + return immutabledict( + dict.__or__(self, __value), # type: ignore[call-overload] + ) + + def __ror__( # type: ignore[override] + self, __value: Mapping[_KT, _VT], / + ) -> immutabledict[_KT, _VT]: + return immutabledict( + dict.__ror__(self, __value), # type: ignore[call-overload] + ) diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py deleted file mode 100644 index e05626eaf7..0000000000 --- a/lib/sqlalchemy/util/_py_collections.py +++ /dev/null @@ -1,541 +0,0 @@ -# util/_py_collections.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls - 
-from __future__ import annotations - -from itertools import filterfalse -from typing import AbstractSet -from typing import Any -from typing import Callable -from typing import cast -from typing import Collection -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import NoReturn -from typing import Optional -from typing import Set -from typing import Tuple -from typing import TYPE_CHECKING -from typing import TypeVar -from typing import Union - -from ..util.typing import Self - -_T = TypeVar("_T", bound=Any) -_S = TypeVar("_S", bound=Any) -_KT = TypeVar("_KT", bound=Any) -_VT = TypeVar("_VT", bound=Any) - - -class ReadOnlyContainer: - __slots__ = () - - def _readonly(self, *arg: Any, **kw: Any) -> NoReturn: - raise TypeError( - "%s object is immutable and/or readonly" % self.__class__.__name__ - ) - - def _immutable(self, *arg: Any, **kw: Any) -> NoReturn: - raise TypeError("%s object is immutable" % self.__class__.__name__) - - def __delitem__(self, key: Any) -> NoReturn: - self._readonly() - - def __setitem__(self, key: Any, value: Any) -> NoReturn: - self._readonly() - - def __setattr__(self, key: str, value: Any) -> NoReturn: - self._readonly() - - -class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]): - if TYPE_CHECKING: - - def __new__(cls, *args: Any) -> Self: ... - - def __init__(cls, *args: Any): ... - - def _readonly(self, *arg: Any, **kw: Any) -> NoReturn: - self._immutable() - - def clear(self) -> NoReturn: - self._readonly() - - def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn: - self._readonly() - - def popitem(self) -> NoReturn: - self._readonly() - - def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn: - self._readonly() - - def update(self, *arg: Any, **kw: Any) -> NoReturn: - self._readonly() - - -class immutabledict(ImmutableDictBase[_KT, _VT]): - def __new__(cls, *args): - new = ImmutableDictBase.__new__(cls) - dict.__init__(new, *args) - return new - - def __init__( - self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]] - ): - pass - - def __reduce__(self): - return immutabledict, (dict(self),) - - def union( - self, __d: Optional[Mapping[_KT, _VT]] = None - ) -> immutabledict[_KT, _VT]: - if not __d: - return self - - new = ImmutableDictBase.__new__(self.__class__) - dict.__init__(new, self) - dict.update(new, __d) # type: ignore - return new - - def _union_w_kw( - self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT - ) -> immutabledict[_KT, _VT]: - # not sure if C version works correctly w/ this yet - if not __d and not kw: - return self - - new = ImmutableDictBase.__new__(self.__class__) - dict.__init__(new, self) - if __d: - dict.update(new, __d) # type: ignore - dict.update(new, kw) # type: ignore - return new - - def merge_with( - self, *dicts: Optional[Mapping[_KT, _VT]] - ) -> immutabledict[_KT, _VT]: - new = None - for d in dicts: - if d: - if new is None: - new = ImmutableDictBase.__new__(self.__class__) - dict.__init__(new, self) - dict.update(new, d) # type: ignore - if new is None: - return self - - return new - - def __repr__(self) -> str: - return "immutabledict(%s)" % dict.__repr__(self) - - # PEP 584 - def __ior__(self, __value: Any, /) -> NoReturn: # type: ignore - self._readonly() - - def __or__( # type: ignore[override] - self, __value: Mapping[_KT, _VT], / - ) -> immutabledict[_KT, _VT]: - return immutabledict( - super().__or__(__value), # type: ignore[call-overload] - ) - - def __ror__( # 
type: ignore[override] - self, __value: Mapping[_KT, _VT], / - ) -> immutabledict[_KT, _VT]: - return immutabledict( - super().__ror__(__value), # type: ignore[call-overload] - ) - - -class OrderedSet(Set[_T]): - __slots__ = ("_list",) - - _list: List[_T] - - def __init__(self, d: Optional[Iterable[_T]] = None) -> None: - if d is not None: - self._list = unique_list(d) - super().update(self._list) - else: - self._list = [] - - def copy(self) -> OrderedSet[_T]: - cp = self.__class__() - cp._list = self._list.copy() - set.update(cp, cp._list) - return cp - - def add(self, element: _T) -> None: - if element not in self: - self._list.append(element) - super().add(element) - - def remove(self, element: _T) -> None: - super().remove(element) - self._list.remove(element) - - def pop(self) -> _T: - try: - value = self._list.pop() - except IndexError: - raise KeyError("pop from an empty set") from None - super().remove(value) - return value - - def insert(self, pos: int, element: _T) -> None: - if element not in self: - self._list.insert(pos, element) - super().add(element) - - def discard(self, element: _T) -> None: - if element in self: - self._list.remove(element) - super().remove(element) - - def clear(self) -> None: - super().clear() - self._list = [] - - def __getitem__(self, key: int) -> _T: - return self._list[key] - - def __iter__(self) -> Iterator[_T]: - return iter(self._list) - - def __add__(self, other: Iterator[_T]) -> OrderedSet[_T]: - return self.union(other) - - def __repr__(self) -> str: - return "%s(%r)" % (self.__class__.__name__, self._list) - - __str__ = __repr__ - - def update(self, *iterables: Iterable[_T]) -> None: - for iterable in iterables: - for e in iterable: - if e not in self: - self._list.append(e) - super().add(e) - - def __ior__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: - self.update(other) - return self - - def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]: - result: OrderedSet[Union[_T, _S]] = self.copy() - result.update(*other) - return result - - def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: - return self.union(other) - - def intersection(self, *other: Iterable[Any]) -> OrderedSet[_T]: - other_set: Set[Any] = set() - other_set.update(*other) - return self.__class__(a for a in self if a in other_set) - - def __and__(self, other: AbstractSet[object]) -> OrderedSet[_T]: - return self.intersection(other) - - def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]: - collection: Collection[_T] - if isinstance(other, set): - collection = other_set = other - elif isinstance(other, Collection): - collection = other - other_set = set(other) - else: - collection = list(other) - other_set = set(collection) - result = self.__class__(a for a in self if a not in other_set) - result.update(a for a in collection if a not in self) - return result - - def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: - return cast(OrderedSet[Union[_T, _S]], self).symmetric_difference( - other - ) - - def difference(self, *other: Iterable[Any]) -> OrderedSet[_T]: - other_set = super().difference(*other) - return self.__class__(a for a in self._list if a in other_set) - - def __sub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]: - return self.difference(other) - - def intersection_update(self, *other: Iterable[Any]) -> None: - super().intersection_update(*other) - self._list = [a for a in self._list if a in self] - - def __iand__(self, other: AbstractSet[object]) -> OrderedSet[_T]: - 
self.intersection_update(other) - return self - - def symmetric_difference_update(self, other: Iterable[Any]) -> None: - collection = other if isinstance(other, Collection) else list(other) - super().symmetric_difference_update(collection) - self._list = [a for a in self._list if a in self] - self._list += [a for a in collection if a in self] - - def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: - self.symmetric_difference_update(other) - return cast(OrderedSet[Union[_T, _S]], self) - - def difference_update(self, *other: Iterable[Any]) -> None: - super().difference_update(*other) - self._list = [a for a in self._list if a in self] - - def __isub__(self, other: AbstractSet[Optional[_T]]) -> OrderedSet[_T]: # type: ignore # noqa: E501 - self.difference_update(other) - return self - - -class IdentitySet: - """A set that considers only object id() for uniqueness. - - This strategy has edge cases for builtin types- it's possible to have - two 'foo' strings in one of these sets, for example. Use sparingly. - - """ - - _members: Dict[int, Any] - - def __init__(self, iterable: Optional[Iterable[Any]] = None): - self._members = dict() - if iterable: - self.update(iterable) - - def add(self, value: Any) -> None: - self._members[id(value)] = value - - def __contains__(self, value: Any) -> bool: - return id(value) in self._members - - def remove(self, value: Any) -> None: - del self._members[id(value)] - - def discard(self, value: Any) -> None: - try: - self.remove(value) - except KeyError: - pass - - def pop(self) -> Any: - try: - pair = self._members.popitem() - return pair[1] - except KeyError: - raise KeyError("pop from an empty set") - - def clear(self) -> None: - self._members.clear() - - def __eq__(self, other: Any) -> bool: - if isinstance(other, IdentitySet): - return self._members == other._members - else: - return False - - def __ne__(self, other: Any) -> bool: - if isinstance(other, IdentitySet): - return self._members != other._members - else: - return True - - def issubset(self, iterable: Iterable[Any]) -> bool: - if isinstance(iterable, self.__class__): - other = iterable - else: - other = self.__class__(iterable) - - if len(self) > len(other): - return False - for m in filterfalse( - other._members.__contains__, iter(self._members.keys()) - ): - return False - return True - - def __le__(self, other: Any) -> bool: - if not isinstance(other, IdentitySet): - return NotImplemented - return self.issubset(other) - - def __lt__(self, other: Any) -> bool: - if not isinstance(other, IdentitySet): - return NotImplemented - return len(self) < len(other) and self.issubset(other) - - def issuperset(self, iterable: Iterable[Any]) -> bool: - if isinstance(iterable, self.__class__): - other = iterable - else: - other = self.__class__(iterable) - - if len(self) < len(other): - return False - - for m in filterfalse( - self._members.__contains__, iter(other._members.keys()) - ): - return False - return True - - def __ge__(self, other: Any) -> bool: - if not isinstance(other, IdentitySet): - return NotImplemented - return self.issuperset(other) - - def __gt__(self, other: Any) -> bool: - if not isinstance(other, IdentitySet): - return NotImplemented - return len(self) > len(other) and self.issuperset(other) - - def union(self, iterable: Iterable[Any]) -> IdentitySet: - result = self.__class__() - members = self._members - result._members.update(members) - result._members.update((id(obj), obj) for obj in iterable) - return result - - def __or__(self, other: Any) -> IdentitySet: - if 
not isinstance(other, IdentitySet): - return NotImplemented - return self.union(other) - - def update(self, iterable: Iterable[Any]) -> None: - self._members.update((id(obj), obj) for obj in iterable) - - def __ior__(self, other: Any) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - self.update(other) - return self - - def difference(self, iterable: Iterable[Any]) -> IdentitySet: - result = self.__new__(self.__class__) - other: Collection[Any] - - if isinstance(iterable, self.__class__): - other = iterable._members - else: - other = {id(obj) for obj in iterable} - result._members = { - k: v for k, v in self._members.items() if k not in other - } - return result - - def __sub__(self, other: IdentitySet) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - return self.difference(other) - - def difference_update(self, iterable: Iterable[Any]) -> None: - self._members = self.difference(iterable)._members - - def __isub__(self, other: IdentitySet) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - self.difference_update(other) - return self - - def intersection(self, iterable: Iterable[Any]) -> IdentitySet: - result = self.__new__(self.__class__) - - other: Collection[Any] - - if isinstance(iterable, self.__class__): - other = iterable._members - else: - other = {id(obj) for obj in iterable} - result._members = { - k: v for k, v in self._members.items() if k in other - } - return result - - def __and__(self, other: IdentitySet) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - return self.intersection(other) - - def intersection_update(self, iterable: Iterable[Any]) -> None: - self._members = self.intersection(iterable)._members - - def __iand__(self, other: IdentitySet) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - self.intersection_update(other) - return self - - def symmetric_difference(self, iterable: Iterable[Any]) -> IdentitySet: - result = self.__new__(self.__class__) - if isinstance(iterable, self.__class__): - other = iterable._members - else: - other = {id(obj): obj for obj in iterable} - result._members = { - k: v for k, v in self._members.items() if k not in other - } - result._members.update( - (k, v) for k, v in other.items() if k not in self._members - ) - return result - - def __xor__(self, other: IdentitySet) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - return self.symmetric_difference(other) - - def symmetric_difference_update(self, iterable: Iterable[Any]) -> None: - self._members = self.symmetric_difference(iterable)._members - - def __ixor__(self, other: IdentitySet) -> IdentitySet: - if not isinstance(other, IdentitySet): - return NotImplemented - self.symmetric_difference(other) - return self - - def copy(self) -> IdentitySet: - result = self.__new__(self.__class__) - result._members = self._members.copy() - return result - - __copy__ = copy - - def __len__(self) -> int: - return len(self._members) - - def __iter__(self) -> Iterator[Any]: - return iter(self._members.values()) - - def __hash__(self) -> NoReturn: - raise TypeError("set objects are unhashable") - - def __repr__(self) -> str: - return "%s(%r)" % (type(self).__name__, list(self._members.values())) - - -def unique_list( - seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None -) -> List[_T]: - seen: Set[Any] = set() - seen_add = seen.add - if not hashfunc: - return [x for x in seq if x not in seen 
and not seen_add(x)]
-    else:
-        return [
-            x
-            for x in seq
-            if hashfunc(x) not in seen and not seen_add(hashfunc(x))
-        ]
diff --git a/lib/sqlalchemy/util/cython.py b/lib/sqlalchemy/util/cython.py
new file mode 100644
index 0000000000..c143138b8e
--- /dev/null
+++ b/lib/sqlalchemy/util/cython.py
@@ -0,0 +1,61 @@
+# util/cython.py
+# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
+#
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Type
+from typing import TypeVar
+
+_T = TypeVar("_T")
+_NO_OP = Callable[[_T], _T]
+
+# cython module shims
+# --
+IS_SHIM = True
+# constants
+compiled = False
+
+# types
+int = int  # noqa: A001
+bint = bool
+longlong = int
+ulonglong = int
+Py_ssize_t = int
+uint = int
+float = float  # noqa: A001
+double = float
+void = Any
+
+
+# functions
+def _no_op(fn: _T) -> _T:
+    return fn
+
+
+cclass = _no_op  # equivalent to "cdef class"
+ccall = _no_op  # equivalent to "cpdef" function
+cfunc = _no_op  # equivalent to "cdef" function
+inline = _no_op
+final = _no_op
+pointer = _no_op  # not sure how to express a pointer to a type
+
+
+def declare(t: Type[_T], value: Any = None, **kw: Any) -> _T:
+    return value  # type: ignore[no-any-return]
+
+
+def annotation_typing(_: bool) -> _NO_OP[_T]:
+    return _no_op
+
+
+def exceptval(value: Any = None, *, check: bool = False) -> _NO_OP[_T]:
+    return _no_op
+
+
+def cast(type_: Type[_T], value: Any, *, typecheck: bool = False) -> _T:
+    return value  # type: ignore[no-any-return]
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index 31c205fbc6..f73a579744 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -15,6 +15,7 @@ from __future__ import annotations
 import collections
 import enum
 from functools import update_wrapper
+import importlib.util
 import inspect
 import itertools
 import operator
@@ -24,6 +25,7 @@ import textwrap
 import threading
 import types
 from types import CodeType
+from types import ModuleType
 from typing import Any
 from typing import Callable
 from typing import cast
@@ -47,18 +49,14 @@
 import warnings
 
 from . import _collections
 from . import compat
-from ._has_cy import HAS_CYEXTENSION
 from .typing import Literal
 from .. import exc
 
 _T = TypeVar("_T")
 _T_co = TypeVar("_T_co", covariant=True)
 _F = TypeVar("_F", bound=Callable[..., Any])
-_MP = TypeVar("_MP", bound="memoized_property[Any]")
 _MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
-_HP = TypeVar("_HP", bound="hybridproperty[Any]")
-_HM = TypeVar("_HM", bound="hybridmethod[Any]")
-
+_M = TypeVar("_M", bound=ModuleType)
 
 if compat.py310:
@@ -2200,6 +2198,8 @@ def repr_tuple_names(names: List[str]) -> Optional[str]:
 
 
 def has_compiled_ext(raise_=False):
+    from ._has_cython import HAS_CYEXTENSION
+
     if HAS_CYEXTENSION:
         return True
     elif raise_:
@@ -2209,3 +2209,27 @@
         )
     else:
         return False
+
+
+def load_uncompiled_module(module: _M) -> _M:
+    """Load the non-compiled version of a module that is also
+    compiled with cython.
+ """ + full_name = module.__name__ + assert module.__spec__ + parent_name = module.__spec__.parent + assert parent_name + parent_module = sys.modules[parent_name] + assert parent_module.__spec__ + package_path = parent_module.__spec__.origin + assert package_path and package_path.endswith("__init__.py") + + name = full_name.split(".")[-1] + module_path = package_path.replace("__init__.py", f"{name}.py") + + py_spec = importlib.util.spec_from_file_location(full_name, module_path) + assert py_spec + py_module = importlib.util.module_from_spec(py_spec) + assert py_spec.loader + py_spec.loader.exec_module(py_module) + return cast(_M, py_module) diff --git a/pyproject.toml b/pyproject.toml index bc9e5706ae..08d2259fdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ [build-system] build-backend = "setuptools.build_meta" requires = [ - "setuptools>=61.2", - "cython>=0.29.24; platform_python_implementation == 'CPython'", # Skip cython when using pypy + "setuptools>=47", + "cython>=3; platform_python_implementation == 'CPython'", # Skip cython when using pypy ] @@ -189,7 +189,10 @@ module = [ warn_unused_ignores = true strict = true +[[tool.mypy.overrides]] +module = ["cython", "cython.*"] +ignore_missing_imports = true [tool.cibuildwheel] test-requires = "pytest pytest-xdist" diff --git a/setup.py b/setup.py index ad4e4002db..e0971fa30d 100644 --- a/setup.py +++ b/setup.py @@ -29,34 +29,36 @@ if DISABLE_EXTENSION and REQUIRE_EXTENSION: "'REQUIRE_SQLALCHEMY_CEXT' environment variables" ) +# when adding a cython module, also update the imports in _has_cython +# it is tested in test_setup_defines_all_files +CYTHON_MODULES = ( + "engine._processors_cy", + "engine._row_cy", + "engine._util_cy", + "sql._util_cy", + "util._collections_cy", + "util._immutabledict_cy", +) if HAS_CYTHON and IS_CPYTHON and not DISABLE_EXTENSION: assert _cy_Extension is not None assert _cy_build_ext is not None - # when adding a cython module, also update the imports in _has_cy - cython_files = [ - "collections.pyx", - "immutabledict.pyx", - "processors.pyx", - "resultproxy.pyx", - "util.pyx", - ] cython_directives = {"language_level": "3"} - module_prefix = "sqlalchemy.cyextension." - source_prefix = "lib/sqlalchemy/cyextension/" + module_prefix = "sqlalchemy." 
+ source_prefix = "lib/sqlalchemy/" ext_modules = cast( "list[Extension]", [ _cy_Extension( - f"{module_prefix}{os.path.splitext(file)[0]}", - sources=[f"{source_prefix}{file}"], + f"{module_prefix}{module}", + sources=[f"{source_prefix}{module.replace('.', '/')}.py"], cython_directives=cython_directives, optional=not REQUIRE_EXTENSION, ) - for file in cython_files + for module in CYTHON_MODULES ], ) diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index fc6be0f096..94629b1416 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -283,7 +283,7 @@ class MemUsageTest(EnsureZeroed): def test_DecimalResultProcessor_init(self): @profile_memory() def go(): - to_decimal_processor_factory({}, 10) + to_decimal_processor_factory(dict, 10) go() diff --git a/test/base/test_result.py b/test/base/test_result.py index 3bbd1b8788..57970c740b 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -1,3 +1,6 @@ +import operator +import sys + from sqlalchemy import exc from sqlalchemy import testing from sqlalchemy.engine import result @@ -11,6 +14,7 @@ from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.util import picklers from sqlalchemy.util import compat +from sqlalchemy.util.langhelpers import load_uncompiled_module class ResultTupleTest(fixtures.TestBase): @@ -96,7 +100,6 @@ class ResultTupleTest(fixtures.TestBase): # row as tuple getter doesn't accept ints. for ints, just # use plain python - import operator getter = operator.itemgetter(2, 0, 1) @@ -201,11 +204,31 @@ class ResultTupleTest(fixtures.TestBase): eq_(kt._fields, ("a", "b")) eq_(kt._asdict(), {"a": 1, "b": 3}) + @testing.fixture + def _load_module(self): + from sqlalchemy.engine import _row_cy as _cy_row + + _py_row = load_uncompiled_module(_cy_row) + + # allow pickle to serialize the two rowproxy_reconstructor functions + # create a new virtual module + new_name = _py_row.__name__ + "py_only" + sys.modules[new_name] = _py_row + _py_row.__name__ = new_name + for item in vars(_py_row).values(): + # only the rowproxy_reconstructor module is required to change, + # but set every one for consistency + if getattr(item, "__module__", None) == _cy_row.__name__: + item.__module__ = new_name + yield _cy_row, _py_row + sys.modules.pop(new_name) + @testing.requires.cextensions @testing.variation("direction", ["py_to_cy", "cy_to_py"]) - def test_serialize_cy_py_cy(self, direction: testing.Variation): - from sqlalchemy.engine import _py_row - from sqlalchemy.cyextension import resultproxy as _cy_row + def test_serialize_cy_py_cy( + self, direction: testing.Variation, _load_module + ): + _cy_row, _py_row = _load_module global Row @@ -256,10 +279,8 @@ class ResultTupleTest(fixtures.TestBase): parent, [None, str, None, str.upper], parent._key_to_index, data ) eq_(row_some_p._to_tuple_instance(), (1, "99", "42", "FOO")) - row_shorter = result.Row( - parent, [None, str], parent._key_to_index, data - ) - eq_(row_shorter._to_tuple_instance(), (1, "99")) + with expect_raises(AssertionError): + result.Row(parent, [None, str], parent._key_to_index, data) def test_tuplegetter(self): data = list(range(10, 20)) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index de8712c852..0ca60c7931 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -37,8 +37,7 @@ from sqlalchemy.util import langhelpers from sqlalchemy.util import preloaded from sqlalchemy.util 
import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering -from sqlalchemy.util._has_cy import _import_cy_extensions -from sqlalchemy.util._has_cy import HAS_CYEXTENSION +from sqlalchemy.util._has_cython import _all_cython_modules class WeakSequenceTest(fixtures.TestBase): @@ -3618,15 +3617,41 @@ class MethodOveriddenTest(fixtures.TestBase): class CyExtensionTest(fixtures.TestBase): - @testing.only_if(lambda: HAS_CYEXTENSION, "No Cython") + __requires__ = ("cextensions",) + def test_all_cyext_imported(self): - ext = _import_cy_extensions() + ext = _all_cython_modules() lib_folder = (Path(__file__).parent / ".." / ".." / "lib").resolve() sa_folder = lib_folder / "sqlalchemy" - cython_files = [f.resolve() for f in sa_folder.glob("**/*.pyx")] + cython_files = [f.resolve() for f in sa_folder.glob("**/*_cy.py")] eq_(len(ext), len(cython_files)) names = { - ".".join(f.relative_to(lib_folder).parts).replace(".pyx", "") + ".".join(f.relative_to(lib_folder).parts).replace(".py", "") for f in cython_files } eq_({m.__name__ for m in ext}, set(names)) + + @testing.combinations(*_all_cython_modules()) + def test_load_uncompiled_module(self, module): + is_true(module._is_compiled()) + py_module = langhelpers.load_uncompiled_module(module) + is_false(py_module._is_compiled()) + eq_(py_module.__name__, module.__name__) + eq_(py_module.__package__, module.__package__) + + def test_setup_defines_all_files(self): + try: + import setuptools # noqa: F401 + except ImportError: + testing.skip_test("setuptools is required") + with mock.patch("setuptools.setup", mock.MagicMock()), mock.patch.dict( + "os.environ", + {"DISABLE_SQLALCHEMY_CEXT": "", "REQUIRE_SQLALCHEMY_CEXT": ""}, + ): + import setup + + setup_modules = {f"sqlalchemy.{m}" for m in setup.CYTHON_MODULES} + expected = {e.__name__ for e in _all_cython_modules()} + print(expected) + print(setup_modules) + eq_(setup_modules, expected) diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py index 5f28e3ea0e..d49396e99d 100644 --- a/test/engine/test_processors.py +++ b/test/engine/test_processors.py @@ -5,9 +5,11 @@ from types import MappingProxyType from sqlalchemy import exc from sqlalchemy.engine import processors from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_none from sqlalchemy.util import immutabledict @@ -33,9 +35,9 @@ class CyBooleanProcessorTest(_BooleanProcessorTest): @classmethod def setup_test_class(cls): - from sqlalchemy.cyextension import processors + from sqlalchemy.engine import _processors_cy - cls.module = processors + cls.module = _processors_cy class _DateProcessorTest(fixtures.TestBase): @@ -72,13 +74,13 @@ class _DateProcessorTest(fixtures.TestBase): eq_(self.module.str_to_date("2022-04-03"), datetime.date(2022, 4, 3)) - def test_date_no_string(self): - assert_raises_message( - TypeError, - "fromisoformat: argument must be str", - self.module.str_to_date, - 2012, - ) + @combinations("str_to_datetime", "str_to_time", "str_to_date") + def test_no_string(self, meth): + with expect_raises_message( + TypeError, "fromisoformat: argument must be str" + ): + fn = getattr(self.module, meth) + fn(2012) def test_datetime_no_string_custom_reg(self): assert_raises_message( @@ -101,37 +103,29 @@ class _DateProcessorTest(fixtures.TestBase): 2012, ) - def test_date_invalid_string(self): - 
assert_raises_message( - ValueError, - "Invalid isoformat string: '5:a'", - self.module.str_to_date, - "5:a", - ) - - def test_datetime_invalid_string(self): - assert_raises_message( - ValueError, - "Invalid isoformat string: '5:a'", - self.module.str_to_datetime, - "5:a", - ) + @combinations("str_to_datetime", "str_to_time", "str_to_date") + def test_invalid_string(self, meth): + with expect_raises_message( + ValueError, "Invalid isoformat string: '5:a'" + ): + fn = getattr(self.module, meth) + fn("5:a") - def test_time_invalid_string(self): - assert_raises_message( - ValueError, - "Invalid isoformat string: '5:a'", - self.module.str_to_time, - "5:a", - ) + @combinations("str_to_datetime", "str_to_time", "str_to_date") + def test_none(self, meth): + fn = getattr(self.module, meth) + is_none(fn(None)) class PyDateProcessorTest(_DateProcessorTest): @classmethod def setup_test_class(cls): - from sqlalchemy.engine import _py_processors + from sqlalchemy.engine import _processors_cy + from sqlalchemy.util.langhelpers import load_uncompiled_module + + py_mod = load_uncompiled_module(_processors_cy) - cls.module = _py_processors + cls.module = py_mod class CyDateProcessorTest(_DateProcessorTest): @@ -139,9 +133,10 @@ class CyDateProcessorTest(_DateProcessorTest): @classmethod def setup_test_class(cls): - from sqlalchemy.cyextension import processors + from sqlalchemy.engine import _processors_cy - cls.module = processors + assert _processors_cy._is_compiled() + cls.module = _processors_cy class _DistillArgsTest(fixtures.TestBase): @@ -281,8 +276,10 @@ class _DistillArgsTest(fixtures.TestBase): class PyDistillArgsTest(_DistillArgsTest): @classmethod def setup_test_class(cls): - from sqlalchemy.engine import _py_util + from sqlalchemy.engine import _util_cy + from sqlalchemy.util.langhelpers import load_uncompiled_module + _py_util = load_uncompiled_module(_util_cy) cls.module = _py_util @@ -291,6 +288,7 @@ class CyDistillArgsTest(_DistillArgsTest): @classmethod def setup_test_class(cls): - from sqlalchemy.cyextension import util + from sqlalchemy.engine import _util_cy - cls.module = util + assert _util_cy._is_compiled() + cls.module = _util_cy diff --git a/test/perf/compiled_extensions.py b/test/perf/compiled_extensions.py index 0982d96ea7..682496a4a8 100644 --- a/test/perf/compiled_extensions.py +++ b/test/perf/compiled_extensions.py @@ -8,6 +8,7 @@ from types import MappingProxyType from sqlalchemy import bindparam from sqlalchemy import column +from sqlalchemy.util.langhelpers import load_uncompiled_module def test_case(fn=None, *, number=None): @@ -48,7 +49,7 @@ class Case: try: return fn() except Exception as e: - print(f"Error loading {fn}: {e}") + print(f"Error loading {fn}: {e!r}") @classmethod def import_object(cls): @@ -92,7 +93,7 @@ class Case: results = defaultdict(dict) for name, impl in objects: - print(f"Running {name} ", end="", flush=True) + print(f"Running {name:<10} ", end="", flush=True) impl_case = cls(impl) fails = [] for m in methods: @@ -121,9 +122,11 @@ class Case: class ImmutableDict(Case): @staticmethod def python(): - from sqlalchemy.util._py_collections import immutabledict + from sqlalchemy.util import _immutabledict_cy - return immutabledict + py_immutabledict = load_uncompiled_module(_immutabledict_cy) + assert not py_immutabledict._is_compiled() + return py_immutabledict.immutabledict @staticmethod def c(): @@ -133,9 +136,10 @@ class ImmutableDict(Case): @staticmethod def cython(): - from sqlalchemy.cyextension.immutabledict import immutabledict + from 
sqlalchemy.util import _immutabledict_cy - return immutabledict + assert _immutabledict_cy._is_compiled() + return _immutabledict_cy.immutabledict IMPLEMENTATIONS = { "python": python.__func__, @@ -179,6 +183,7 @@ class ImmutableDict(Case): @test_case def union(self): self.d1.union(self.small) + self.d1.union(self.small.items()) @test_case def union_large(self): @@ -187,6 +192,7 @@ class ImmutableDict(Case): @test_case def merge_with(self): self.d1.merge_with(self.small) + self.d1.merge_with(self.small.items()) @test_case def merge_with_large(self): @@ -263,12 +269,14 @@ class ImmutableDict(Case): self.d1 != "foo" -class Processor(Case): +class Processors(Case): @staticmethod def python(): - from sqlalchemy.engine import processors + from sqlalchemy.engine import _processors_cy - return processors + py_processors = load_uncompiled_module(_processors_cy) + assert not py_processors._is_compiled() + return py_processors @staticmethod def c(): @@ -282,13 +290,10 @@ class Processor(Case): @staticmethod def cython(): - from sqlalchemy.cyextension import processors as mod + from sqlalchemy.engine import _processors_cy - mod.to_decimal_processor_factory = ( - lambda t, s: mod.DecimalResultProcessor(t, "%%.%df" % s).process - ) - - return mod + assert _processors_cy._is_compiled() + return _processors_cy IMPLEMENTATIONS = { "python": python.__func__, @@ -298,10 +303,7 @@ class Processor(Case): NUMBER = 500_000 def init_objects(self): - self.to_dec = self.impl.to_decimal_processor_factory(Decimal, 10) - - self.bytes = token_urlsafe(2048).encode() - self.text = token_urlsafe(2048) + self.to_dec = self.impl.to_decimal_processor_factory(Decimal, 3) @classmethod def update_results(cls, results): @@ -323,6 +325,7 @@ class Processor(Case): self.impl.to_str(123) self.impl.to_str(True) self.impl.to_str(self) + self.impl.to_str("self") @test_case def to_float(self): @@ -332,6 +335,9 @@ class Processor(Case): self.impl.to_float(42) self.impl.to_float(0) self.impl.to_float(42.0) + self.impl.to_float("nan") + self.impl.to_float("42") + self.impl.to_float("42.0") @test_case def str_to_datetime(self): @@ -351,11 +357,16 @@ class Processor(Case): self.impl.str_to_date("2020-01-01") @test_case - def to_decimal(self): - self.to_dec(None) is None + def to_decimal_call(self): + assert self.to_dec(None) is None self.to_dec(123.44) self.to_dec(99) - self.to_dec(99) + self.to_dec(1 / 3) + + @test_case + def to_decimal_pf_make(self): + self.impl.to_decimal_processor_factory(Decimal, 3) + self.impl.to_decimal_processor_factory(Decimal, 7) class DistillParam(Case): @@ -363,15 +374,18 @@ class DistillParam(Case): @staticmethod def python(): - from sqlalchemy.engine import _py_util + from sqlalchemy.engine import _util_cy - return _py_util + py_util = load_uncompiled_module(_util_cy) + assert not py_util._is_compiled() + return py_util @staticmethod def cython(): - from sqlalchemy.cyextension import util as mod + from sqlalchemy.engine import _util_cy - return mod + assert _util_cy._is_compiled() + return _util_cy IMPLEMENTATIONS = { "python": python.__func__, @@ -458,15 +472,18 @@ class IdentitySet(Case): @staticmethod def python(): - from sqlalchemy.util._py_collections import IdentitySet + from sqlalchemy.util import _collections_cy - return IdentitySet + py_coll = load_uncompiled_module(_collections_cy) + assert not py_coll._is_compiled() + return py_coll.IdentitySet @staticmethod def cython(): - from sqlalchemy.cyextension import collections + from sqlalchemy.util import _collections_cy - return collections.IdentitySet + 
assert _collections_cy._is_compiled() + return _collections_cy.IdentitySet IMPLEMENTATIONS = { "set": set_fn.__func__, @@ -478,7 +495,6 @@ class IdentitySet(Case): def init_objects(self): self.val1 = list(range(10)) self.val2 = list(wrap(token_urlsafe(4 * 2048), 4)) - self.imp_1 = self.impl(self.val1) self.imp_2 = self.impl(self.val2) @@ -488,45 +504,41 @@ class IdentitySet(Case): cls._divide_results(results, "cython", "python", "cy / py") cls._divide_results(results, "cython", "set", "cy / set") - @test_case + @test_case(number=2_500_000) def init_empty(self): - i = self.impl - for _ in range(10000): - i() + self.impl() - @test_case + @test_case(number=2_500) def init(self): - i, v = self.impl, self.val2 - for _ in range(500): - i(v) + self.impl(self.val1) + self.impl(self.val2) - @test_case + @test_case(number=5_000) def init_from_impl(self): - for _ in range(500): - self.impl(self.imp_2) + self.impl(self.imp_2) - @test_case + @test_case(number=100) def add(self): ii = self.impl() - for _ in range(10): - for i in range(1000): - ii.add(str(i)) + x = 25_000 + for i in range(x): + ii.add(str(i % (x / 2))) @test_case def contains(self): ii = self.impl(self.val2) - for _ in range(500): + for _ in range(1_000): for x in self.val1 + self.val2: x in ii - @test_case + @test_case(number=200) def remove(self): v = [str(i) for i in range(7500)] ii = self.impl(v) for x in v[:5000]: ii.remove(x) - @test_case + @test_case(number=200) def discard(self): v = [str(i) for i in range(7500)] ii = self.impl(v) @@ -535,7 +547,7 @@ class IdentitySet(Case): @test_case def pop(self): - for x in range(1000): + for x in range(50_000): ii = self.impl(self.val1) for x in self.val1: ii.pop() @@ -543,152 +555,137 @@ class IdentitySet(Case): @test_case def clear(self): i, v = self.impl, self.val1 - for _ in range(5000): + for _ in range(125_000): ii = i(v) ii.clear() - @test_case + @test_case(number=2_500_000) def eq(self): - for x in range(1000): - self.imp_1 == self.imp_1 - self.imp_1 == self.imp_2 - self.imp_1 == self.val2 + self.imp_1 == self.imp_1 + self.imp_1 == self.imp_2 + self.imp_1 == self.val2 - @test_case + @test_case(number=2_500_000) def ne(self): - for x in range(1000): - self.imp_1 != self.imp_1 - self.imp_1 != self.imp_2 - self.imp_1 != self.val2 + self.imp_1 != self.imp_1 + self.imp_1 != self.imp_2 + self.imp_1 != self.val2 - @test_case + @test_case(number=20_000) def issubset(self): - for _ in range(250): - self.imp_1.issubset(self.imp_1) - self.imp_1.issubset(self.imp_2) - self.imp_1.issubset(self.val1) - self.imp_1.issubset(self.val2) + self.imp_1.issubset(self.imp_1) + self.imp_1.issubset(self.imp_2) + self.imp_1.issubset(self.val1) + self.imp_1.issubset(self.val2) - @test_case + @test_case(number=50_000) def le(self): - for x in range(1000): - self.imp_1 <= self.imp_1 - self.imp_1 <= self.imp_2 - self.imp_2 <= self.imp_1 - self.imp_2 <= self.imp_2 + self.imp_1 <= self.imp_1 + self.imp_1 <= self.imp_2 + self.imp_2 <= self.imp_1 + self.imp_2 <= self.imp_2 - @test_case + @test_case(number=2_500_000) def lt(self): - for x in range(2500): - self.imp_1 < self.imp_1 - self.imp_1 < self.imp_2 - self.imp_2 < self.imp_1 - self.imp_2 < self.imp_2 + self.imp_1 < self.imp_1 + self.imp_1 < self.imp_2 + self.imp_2 < self.imp_1 + self.imp_2 < self.imp_2 - @test_case + @test_case(number=20_000) def issuperset(self): - for _ in range(250): - self.imp_1.issuperset(self.imp_1) - self.imp_1.issuperset(self.imp_2) - self.imp_1.issubset(self.val1) - self.imp_1.issubset(self.val2) + self.imp_1.issuperset(self.imp_1) + 
self.imp_1.issuperset(self.imp_2) + self.imp_1.issubset(self.val1) + self.imp_1.issubset(self.val2) - @test_case + @test_case(number=50_000) def ge(self): - for x in range(1000): - self.imp_1 >= self.imp_1 - self.imp_1 >= self.imp_2 - self.imp_2 >= self.imp_1 - self.imp_2 >= self.imp_2 + self.imp_1 >= self.imp_1 + self.imp_1 >= self.imp_2 + self.imp_2 >= self.imp_1 + self.imp_2 >= self.imp_2 - @test_case + @test_case(number=2_500_000) def gt(self): - for x in range(2500): - self.imp_1 > self.imp_1 - self.imp_2 > self.imp_2 - self.imp_2 > self.imp_1 - self.imp_2 > self.imp_2 + self.imp_1 > self.imp_1 + self.imp_2 > self.imp_2 + self.imp_2 > self.imp_1 + self.imp_2 > self.imp_2 - @test_case + @test_case(number=10_000) def union(self): - for _ in range(250): - self.imp_1.union(self.imp_2) + self.imp_1.union(self.imp_2) - @test_case + @test_case(number=10_000) def or_test(self): - for _ in range(250): - self.imp_1 | self.imp_2 + self.imp_1 | self.imp_2 @test_case def update(self): ii = self.impl(self.val1) - for _ in range(250): + for _ in range(1_000): ii.update(self.imp_2) @test_case def ior(self): ii = self.impl(self.val1) - for _ in range(250): + for _ in range(1_000): ii |= self.imp_2 @test_case def difference(self): - for _ in range(250): + for _ in range(2_500): self.imp_1.difference(self.imp_2) self.imp_1.difference(self.val2) - @test_case + @test_case(number=250_000) def sub(self): - for _ in range(500): - self.imp_1 - self.imp_2 + self.imp_1 - self.imp_2 @test_case def difference_update(self): ii = self.impl(self.val1) - for _ in range(250): + for _ in range(2_500): ii.difference_update(self.imp_2) ii.difference_update(self.val2) @test_case def isub(self): ii = self.impl(self.val1) - for _ in range(500): + for _ in range(250_000): ii -= self.imp_2 - @test_case + @test_case(number=20_000) def intersection(self): - for _ in range(250): - self.imp_1.intersection(self.imp_2) - self.imp_1.intersection(self.val2) + self.imp_1.intersection(self.imp_2) + self.imp_1.intersection(self.val2) - @test_case + @test_case(number=250_000) def and_test(self): - for _ in range(500): - self.imp_1 & self.imp_2 + self.imp_1 & self.imp_2 @test_case def intersection_up(self): ii = self.impl(self.val1) - for _ in range(250): + for _ in range(2_500): ii.intersection_update(self.imp_2) ii.intersection_update(self.val2) @test_case def iand(self): ii = self.impl(self.val1) - for _ in range(500): + for _ in range(250_000): ii &= self.imp_2 - @test_case + @test_case(number=2_500) def symmetric_diff(self): - for _ in range(125): - self.imp_1.symmetric_difference(self.imp_2) - self.imp_1.symmetric_difference(self.val2) + self.imp_1.symmetric_difference(self.imp_2) + self.imp_1.symmetric_difference(self.val2) - @test_case + @test_case(number=2_500) def xor(self): - for _ in range(250): - self.imp_1 ^ self.imp_2 + self.imp_1 ^ self.imp_2 @test_case def symmetric_diff_up(self): @@ -703,29 +700,25 @@ class IdentitySet(Case): for _ in range(250): ii ^= self.imp_2 - @test_case + @test_case(number=25_000) def copy(self): - for _ in range(250): - self.imp_1.copy() - self.imp_2.copy() + self.imp_1.copy() + self.imp_2.copy() - @test_case + @test_case(number=2_500_000) def len(self): - for x in range(5000): - len(self.imp_1) - len(self.imp_2) + len(self.imp_1) + len(self.imp_2) - @test_case + @test_case(number=25_000) def iter(self): - for _ in range(2000): - list(self.imp_1) - list(self.imp_2) + list(self.imp_1) + list(self.imp_2) - @test_case + @test_case(number=10_000) def repr(self): - for _ in range(250): - 
str(self.imp_1) - str(self.imp_2) + str(self.imp_1) + str(self.imp_2) class OrderedSet(IdentitySet): @@ -735,15 +728,18 @@ class OrderedSet(IdentitySet): @staticmethod def python(): - from sqlalchemy.util._py_collections import OrderedSet + from sqlalchemy.util import _collections_cy - return OrderedSet + py_coll = load_uncompiled_module(_collections_cy) + assert not py_coll._is_compiled() + return py_coll.OrderedSet @staticmethod def cython(): - from sqlalchemy.cyextension import collections + from sqlalchemy.util import _collections_cy - return collections.OrderedSet + assert _collections_cy._is_compiled() + return _collections_cy.OrderedSet @staticmethod def ordered_lib(): @@ -768,22 +764,87 @@ class OrderedSet(IdentitySet): def add_op(self): ii = self.impl(self.val1) v2 = self.impl(self.val2) - for _ in range(1000): + for _ in range(500): ii + v2 @test_case def getitem(self): ii = self.impl(self.val1) - for _ in range(1000): + for _ in range(250_000): for i in range(len(self.val1)): ii[i] @test_case def insert(self): - ii = self.impl(self.val1) for _ in range(5): - for i in range(1000): - ii.insert(-i % 2, 1) + ii = self.impl(self.val1) + for i in range(5_000): + ii.insert(i // 2, i) + ii.insert(-i % 2, i) + + +class UniqueList(Case): + @staticmethod + def python(): + from sqlalchemy.util import _collections_cy + + py_coll = load_uncompiled_module(_collections_cy) + assert not py_coll._is_compiled() + return py_coll.unique_list + + @staticmethod + def cython(): + from sqlalchemy.util import _collections_cy + + assert _collections_cy._is_compiled() + return _collections_cy.unique_list + + IMPLEMENTATIONS = { + "python": python.__func__, + "cython": cython.__func__, + } + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "cython", "python", "cy / py") + + def init_objects(self): + self.int_small = list(range(10)) + self.int_vlarge = list(range(25_000)) * 2 + d = wrap(token_urlsafe(100 * 2048), 4) + assert len(d) > 50_000 + self.vlarge = d[:50_000] + self.large = d[:500] + self.small = d[:15] + + @test_case + def small_str(self): + self.impl(self.small) + + @test_case(number=50_000) + def large_str(self): + self.impl(self.large) + + @test_case(number=250) + def vlarge_str(self): + self.impl(self.vlarge) + + @test_case + def small_range(self): + self.impl(range(10)) + + @test_case + def small_int(self): + self.impl(self.int_small) + + @test_case(number=25_000) + def large_int(self): + self.impl([1, 1, 1, 2, 3] * 100) + self.impl(range(1000)) + + @test_case(number=250) + def vlarge_int(self): + self.impl(self.int_vlarge) class TupleGetter(Case): @@ -791,9 +852,11 @@ class TupleGetter(Case): @staticmethod def python(): - from sqlalchemy.engine._py_row import tuplegetter + from sqlalchemy.engine import _util_cy - return tuplegetter + py_util = load_uncompiled_module(_util_cy) + assert not py_util._is_compiled() + return py_util.tuplegetter @staticmethod def c(): @@ -803,9 +866,10 @@ class TupleGetter(Case): @staticmethod def cython(): - from sqlalchemy.cyextension import resultproxy + from sqlalchemy.engine import _util_cy - return resultproxy.tuplegetter + assert _util_cy._is_compiled() + return _util_cy.tuplegetter IMPLEMENTATIONS = { "python": python.__func__, @@ -855,9 +919,11 @@ class TupleGetter(Case): class BaseRow(Case): @staticmethod def python(): - from sqlalchemy.engine._py_row import BaseRow + from sqlalchemy.engine import _row_cy - return BaseRow + py_res = load_uncompiled_module(_row_cy) + assert not py_res._is_compiled() + return py_res.BaseRow 
@staticmethod def c(): @@ -867,9 +933,10 @@ class BaseRow(Case): @staticmethod def cython(): - from sqlalchemy.cyextension import resultproxy + from sqlalchemy.engine import _row_cy - return resultproxy.BaseRow + assert _row_cy._is_compiled() + return _row_cy.BaseRow IMPLEMENTATIONS = { "python": python.__func__, @@ -909,9 +976,11 @@ class BaseRow(Case): self.row_long_state = self.row_long.__getstate__() assert len(ascii_letters) == 52 + _proc = [None, int, float, None, str] * 10 + _proc += [int, float] self.parent_proc = SimpleResultMetaData( tuple(ascii_letters), - _processors=[None, int, float, None, str] * 10, # cut the last 2 + _processors=_proc, ) self.row_proc_args = ( self.parent_proc, @@ -1024,7 +1093,7 @@ class BaseRow(Case): self.row_long.x self.row_long.y - @test_case(number=50_000) + @test_case(number=25_000) def get_by_key_recreate(self): self.init_objects() row = self.row @@ -1041,7 +1110,7 @@ class BaseRow(Case): l_row._get_by_key_impl_mapping("w") l_row._get_by_key_impl_mapping("o") - @test_case(number=50_000) + @test_case(number=10_000) def getattr_recreate(self): self.init_objects() row = self.row @@ -1059,18 +1128,21 @@ class BaseRow(Case): l_row.o -class CacheAnonMap(Case): +class AnonMap(Case): @staticmethod def python(): - from sqlalchemy.sql._py_util import cache_anon_map + from sqlalchemy.sql import _util_cy - return cache_anon_map + py_util = load_uncompiled_module(_util_cy) + assert not py_util._is_compiled() + return py_util.anon_map @staticmethod def cython(): - from sqlalchemy.cyextension.util import cache_anon_map + from sqlalchemy.sql import _util_cy - return cache_anon_map + assert _util_cy._is_compiled() + return _util_cy.anon_map IMPLEMENTATIONS = {"python": python.__func__, "cython": cython.__func__} @@ -1090,34 +1162,41 @@ class CacheAnonMap(Case): cls._divide_results(results, "cython", "python", "cy / py") @test_case - def test_get_anon_non_present(self): + def test_make(self): + self.impl() + + @test_case + def test_get_anon_np(self): self.impl_w_non_present.get_anon(self.object_1) @test_case - def test_get_anon_present(self): + def test_get_anon_p(self): self.impl_w_present.get_anon(self.object_1) @test_case - def test_has_key_non_present(self): + def test_has_key_np(self): id(self.object_1) in self.impl_w_non_present @test_case - def test_has_key_present(self): + def test_has_key_p(self): id(self.object_1) in self.impl_w_present class PrefixAnonMap(Case): @staticmethod def python(): - from sqlalchemy.sql._py_util import prefix_anon_map + from sqlalchemy.sql import _util_cy - return prefix_anon_map + py_util = load_uncompiled_module(_util_cy) + assert not py_util._is_compiled() + return py_util.prefix_anon_map @staticmethod def cython(): - from sqlalchemy.cyextension.util import prefix_anon_map + from sqlalchemy.sql import _util_cy - return prefix_anon_map + assert _util_cy._is_compiled() + return _util_cy.prefix_anon_map IMPLEMENTATIONS = {"python": python.__func__, "cython": cython.__func__} @@ -1137,11 +1216,15 @@ class PrefixAnonMap(Case): cls._divide_results(results, "cython", "python", "cy / py") @test_case - def test_apply_non_present(self): + def test_make(self): + self.impl() + + @test_case + def test_apply_np(self): self.name.apply_map(self.impl_w_non_present) @test_case - def test_apply_present(self): + def test_apply_p(self): self.name.apply_map(self.impl_w_present) diff --git a/test/profiles.txt b/test/profiles.txt index d8226f4a89..d1549bf947 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -134,7 +134,7 @@ 
test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_cached x86_64_li # TEST: test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 4003 -test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 6103 +test.aaa_profiling.test_misc.CacheKeyTest.test_statement_key_is_not_cached x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 7503 # TEST: test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_members @@ -387,7 +387,7 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_6 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 2649 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 14656 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 2614 -test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 14621 +test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 36612 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] @@ -413,7 +413,7 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 14 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 15 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 14 -test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 15 +test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 16 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] @@ -426,7 +426,7 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 14 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 15 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 14 -test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 15 +test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-2] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 16 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] @@ -439,7 +439,7 @@ 
test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_ test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 17 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 18 test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 17 -test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 18 +test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[True-1] x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 19 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string @@ -452,7 +452,7 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpy test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 299 test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5301 test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 272 -test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5274 +test.aaa_profiling.test_resultset.ResultSetTest.test_raw_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 6272 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode @@ -465,7 +465,7 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cp test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 299 test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5301 test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 272 -test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5274 +test.aaa_profiling.test_resultset.ResultSetTest.test_raw_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 6272 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_string @@ -478,7 +478,7 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 640 test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5647 test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 605 -test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5612 +test.aaa_profiling.test_resultset.ResultSetTest.test_string x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 6603 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_unicode @@ -491,4 +491,4 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_unicode 
x86_64_linux_cpytho test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_cextensions 640 test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_postgresql_psycopg2_dbapiunicode_nocextensions 5647 test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_cextensions 605 -test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 5612 +test.aaa_profiling.test_resultset.ResultSetTest.test_unicode x86_64_linux_cpython_3.11_sqlite_pysqlite_dbapiunicode_nocextensions 6603 diff --git a/tools/cython_imports.py b/tools/cython_imports.py new file mode 100644 index 0000000000..4e7a425da5 --- /dev/null +++ b/tools/cython_imports.py @@ -0,0 +1,73 @@ +from pathlib import Path +import re + + +from sqlalchemy.util.tool_support import code_writer_cmd + +sa_path = Path(__file__).parent.parent / "lib/sqlalchemy" + + +section_re = re.compile( + r"^# START GENERATED CYTHON IMPORT$\n(.*)\n" + r"^# END GENERATED CYTHON IMPORT$", + re.MULTILINE | re.DOTALL, +) +# start = re.compile("^# START GENERATED CYTHON IMPORT$") +# end = re.compile("^# END GENERATED CYTHON IMPORT$") +code = '''\ +# START GENERATED CYTHON IMPORT +# This section is automatically generated by the script tools/cython_imports.py +try: + # NOTE: the cython compiler needs this "import cython" in the file, it + # can't be only "from sqlalchemy.util import cython" with the fallback + # in that module + import cython +except ModuleNotFoundError: + from sqlalchemy.util import cython + + +def _is_compiled() -> bool: + """Utility function to indicate if this module is compiled or not.""" + return cython.compiled # type: ignore[no-any-return] + + +# END GENERATED CYTHON IMPORT\ +''' + + +def run_file(cmd: code_writer_cmd, file: Path): + content = file.read_text("utf-8") + count = 0 + + def repl_fn(match): + nonlocal count + count += 1 + return code + + content = section_re.sub(repl_fn, content) + if count == 0: + raise ValueError( + "Expected to find comment '# START GENERATED CYTHON IMPORT' " + f"in cython file {file}, but none found" + ) + if count > 1: + raise ValueError( + "Expected to find a single comment '# START GENERATED CYTHON " + f"IMPORT' in cython file {file}, but {count} found" + ) + cmd.write_output_file_from_text(content, file) + + +def run(cmd: code_writer_cmd): + i = 0 + for file in sa_path.glob(f"**/*_cy.py"): + run_file(cmd, file) + i += 1 + cmd.write_status(f"\nDone. 
Processed {i} files.") + + +if __name__ == "__main__": + cmd = code_writer_cmd(__file__) + + with cmd.run_program(): + run(cmd) diff --git a/tox.ini b/tox.ini index 22446bb844..14a873844c 100644 --- a/tox.ini +++ b/tox.ini @@ -241,15 +241,14 @@ commands = # run flake8-unused-arguments only on some files / modules flake8 --extend-ignore='' ./lib/sqlalchemy/ext/asyncio ./lib/sqlalchemy/orm/scoping.py black --check ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py - # test with cython and without cython exts running slotscheck -m sqlalchemy - env DISABLE_SQLALCHEMY_CEXT_RUNTIME=1 slotscheck -m sqlalchemy python ./tools/format_docs_code.py --check python ./tools/generate_tuple_map_overloads.py --check python ./tools/generate_proxy_methods.py --check python ./tools/sync_test_files.py --check python ./tools/generate_sql_functions.py --check python ./tools/normalize_file_headers.py --check + python ./tools/cython_imports.py --check python ./tools/walk_packages.py -- 2.47.2