From: Federico Caselli Date: Thu, 18 Feb 2021 18:12:45 +0000 (+0100) Subject: Minor optimization to the code X-Git-Tag: rel_1_4_0~37^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=629273a31b1be9ed195e9082d40b8741ab31e073;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Minor optimization to the code * remove the c version of distill params since it's actually slower than the python one * add a function to langhelpers to check if the cextensions are active * minor cleanup to the OrderedSet implementation Change-Id: Iec3d0c3f0f42cdf51f802aaca342ba37b8783b85 --- diff --git a/.github/workflows/create-wheels.yaml b/.github/workflows/create-wheels.yaml index b11a14619c..e6fd1e3bd5 100644 --- a/.github/workflows/create-wheels.yaml +++ b/.github/workflows/create-wheels.yaml @@ -87,7 +87,7 @@ jobs: # for python 2.7 visual studio 9 is missing if: matrix.os != 'windows-latest' || matrix.python-version != '2.7' run: | - python -c 'from sqlalchemy import cprocessors, cresultproxy, cutils' + python -c 'from sqlalchemy.util import has_compiled_ext; assert has_compiled_ext()' - name: Test created wheel # the mock reconnect test seems to fail on the ci in windows @@ -222,7 +222,7 @@ jobs: then pip install greenlet "importlib-metadata;python_version<'3.8'" pip install -f dist --no-index sqlalchemy - python -c 'from sqlalchemy import cprocessors, cresultproxy, cutils' + python -c 'from sqlalchemy.util import has_compiled_ext; assert has_compiled_ext()' pip install pytest pytest-xdist ${{ matrix.extra-requires }} pytest -n2 -q test --nomemory --notimingintensive else @@ -362,7 +362,7 @@ jobs: python --version && pip install greenlet \"importlib-metadata;python_version<'3.8'\" && pip install -f dist --no-index sqlalchemy && - python -c 'from sqlalchemy import cprocessors, cresultproxy, cutils' && + python -c 'from sqlalchemy.util import has_compiled_ext; assert has_compiled_ext()' && pip install pytest pytest-xdist ${{ matrix.extra-requires }} && pytest -n2 -q test --nomemory --notimingintensive" diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c deleted file mode 100644 index e06843c9de..0000000000 --- a/lib/sqlalchemy/cextension/utils.c +++ /dev/null @@ -1,249 +0,0 @@ -/* -utils.c -Copyright (C) 2012-2021 the SQLAlchemy authors and contributors - -This module is part of SQLAlchemy and is released under -the MIT License: http://www.opensource.org/licenses/mit-license.php -*/ - -#include - -#define MODULE_NAME "cutils" -#define MODULE_DOC "Module containing C versions of utility functions." - -/* - Given arguments from the calling form *multiparams, **params, - return a list of bind parameter structures, usually a list of - dictionaries. - - In the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists. - - */ -static PyObject * -distill_params(PyObject *self, PyObject *args) -{ - // TODO: pass the Connection in so that there can be a standard - // method for warning on parameter format - - PyObject *connection, *multiparams, *params; - PyObject *enclosing_list, *double_enclosing_list; - PyObject *zero_element, *zero_element_item; - PyObject *tmp; - Py_ssize_t multiparam_size, zero_element_length; - - if (!PyArg_UnpackTuple(args, "_distill_params", 3, 3, &connection, &multiparams, ¶ms)) { - return NULL; - } - - if (multiparams != Py_None) { - multiparam_size = PyTuple_Size(multiparams); - if (multiparam_size < 0) { - return NULL; - } - } - else { - multiparam_size = 0; - } - - if (multiparam_size == 0) { - if (params != Py_None && PyMapping_Size(params) != 0) { - - tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); - if (tmp == NULL) { - return NULL; - } - - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(params); - if (PyList_SetItem(enclosing_list, 0, params) == -1) { - Py_DECREF(params); - Py_DECREF(enclosing_list); - return NULL; - } - } - else { - enclosing_list = PyList_New(0); - if (enclosing_list == NULL) { - return NULL; - } - } - return enclosing_list; - } - else if (multiparam_size == 1) { - zero_element = PyTuple_GetItem(multiparams, 0); - if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) { - zero_element_length = PySequence_Length(zero_element); - - if (zero_element_length != 0) { - zero_element_item = PySequence_GetItem(zero_element, 0); - if (zero_element_item == NULL) { - return NULL; - } - } - else { - zero_element_item = NULL; - } - - if (zero_element_length == 0 || - ( - PyObject_HasAttrString(zero_element_item, "__iter__") && - !PyObject_HasAttrString(zero_element_item, "strip") - ) - ) { - /* - * execute(stmt, [{}, {}, {}, ...]) - * execute(stmt, [(), (), (), ...]) - */ - Py_XDECREF(zero_element_item); - Py_INCREF(zero_element); - return zero_element; - } - else { - /* - * execute(stmt, ("value", "value")) - */ - Py_XDECREF(zero_element_item); - - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(zero_element); - if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - return NULL; - } - return enclosing_list; - } - } - else if (PyObject_HasAttrString(zero_element, "keys")) { - /* - * execute(stmt, {"key":"value"}) - */ - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(zero_element); - if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - return NULL; - } - return enclosing_list; - } else { - tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); - if (tmp == NULL) { - return NULL; - } - - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - double_enclosing_list = PyList_New(1); - if (double_enclosing_list == NULL) { - Py_DECREF(enclosing_list); - return NULL; - } - Py_INCREF(zero_element); - if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - Py_DECREF(double_enclosing_list); - return NULL; - } - if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - Py_DECREF(double_enclosing_list); - return NULL; - } - return double_enclosing_list; - } - } - else { - - tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); - if (tmp == NULL) { - return NULL; - } - - zero_element = PyTuple_GetItem(multiparams, 0); - if (PyObject_HasAttrString(zero_element, "__iter__") && - !PyObject_HasAttrString(zero_element, "strip") - ) { - Py_INCREF(multiparams); - return multiparams; - } - else { - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(multiparams); - if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) { - Py_DECREF(multiparams); - Py_DECREF(enclosing_list); - return NULL; - } - return enclosing_list; - } - } -} - -static PyMethodDef module_methods[] = { - {"_distill_params", distill_params, METH_VARARGS, - "Distill an execute() parameter structure."}, - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ -#define PyMODINIT_FUNC void -#endif - -#if PY_MAJOR_VERSION >= 3 - -#define INITERROR return NULL - -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - MODULE_NAME, - MODULE_DOC, - -1, - module_methods - }; - -PyMODINIT_FUNC -PyInit_cutils(void) - -#else - -#define INITERROR return - -PyMODINIT_FUNC -initcutils(void) - -#endif - -{ - PyObject *m; - -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&module_def); -#else - m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); -#endif - if (m == NULL) - INITERROR; - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} - diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 4e302f464f..ede2631985 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -34,63 +34,55 @@ _no_tuple = () _no_kw = util.immutabledict() -def py_fallback(): - # TODO: pass the Connection in so that there can be a standard - # method for warning on parameter format - def _distill_params(connection, multiparams, params): # noqa - r"""Given arguments from the calling form \*multiparams, \**params, - return a list of bind parameter structures, usually a list of - dictionaries. - - In the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists. - - """ - - # C version will fail if this assertion is not true. - # assert isinstance(multiparams, tuple) - - if not multiparams: - if params: - connection._warn_for_legacy_exec_format() - return [params] +def _distill_params(connection, multiparams, params): + r"""Given arguments from the calling form \*multiparams, \**params, + return a list of bind parameter structures, usually a list of + dictionaries. + + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + + """ + + if not multiparams: + if params: + connection._warn_for_legacy_exec_format() + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): + # execute(stmt, [{}, {}, {}, ...]) + # execute(stmt, [(), (), (), ...]) + return zero else: - return [] - elif len(multiparams) == 1: - zero = multiparams[0] - if isinstance(zero, (list, tuple)): - if ( - not zero - or hasattr(zero[0], "__iter__") - and not hasattr(zero[0], "strip") - ): - # execute(stmt, [{}, {}, {}, ...]) - # execute(stmt, [(), (), (), ...]) - return zero - else: - # this is used by exec_driver_sql only, so a deprecation - # warning would already be coming from passing a plain - # textual statement with positional parameters to - # execute(). - # execute(stmt, ("value", "value")) - return [zero] - elif hasattr(zero, "keys"): - # execute(stmt, {"key":"value"}) + # this is used by exec_driver_sql only, so a deprecation + # warning would already be coming from passing a plain + # textual statement with positional parameters to + # execute(). + # execute(stmt, ("value", "value")) return [zero] - else: - connection._warn_for_legacy_exec_format() - # execute(stmt, "value") - return [[zero]] + elif hasattr(zero, "keys"): + # execute(stmt, {"key":"value"}) + return [zero] else: connection._warn_for_legacy_exec_format() - if hasattr(multiparams[0], "__iter__") and not hasattr( - multiparams[0], "strip" - ): - return multiparams - else: - return [multiparams] - - return locals() + # execute(stmt, "value") + return [[zero]] + else: + connection._warn_for_legacy_exec_format() + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): + return multiparams + else: + return [multiparams] def _distill_cursor_params(connection, multiparams, params): @@ -161,9 +153,3 @@ def _distill_params_20(params): return (params,), _no_kw else: raise exc.ArgumentError("mapping or sequence expected for parameters") - - -try: - from sqlalchemy.cutils import _distill_params # noqa -except ImportError: - globals().update(py_fallback()) diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 16c6d458c9..5e4f192736 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -22,6 +22,7 @@ import sys from . import config from .util import gc_collect +from ..util import has_compiled_ext try: @@ -109,7 +110,7 @@ class ProfileStatsFile(object): if config.db.dialect.convert_unicode else "dbapiunicode" ) - _has_cext = config.requirements._has_cextensions() + _has_cext = has_compiled_ext() platform_tokens.append(_has_cext and "cextensions" or "nocextensions") return "_".join(platform_tokens) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index d8da9c8184..f16ba326cc 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1258,7 +1258,7 @@ class SuiteRequirements(Requirements): @property def cextensions(self): return exclusions.skip_if( - lambda: not self._has_cextensions(), "C extensions not installed" + lambda: not util.has_compiled_ext(), "C extensions not installed" ) def _has_sqlite(self): @@ -1270,14 +1270,6 @@ class SuiteRequirements(Requirements): except ImportError: return False - def _has_cextensions(self): - try: - from sqlalchemy import cresultproxy, cprocessors # noqa - - return True - except ImportError: - return False - @property def async_dialect(self): """dialect makes use of await_() to invoke operations on the DBAPI.""" diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 2d86b8b633..4b61658b21 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -140,6 +140,7 @@ from .langhelpers import get_callable_argspec # noqa from .langhelpers import get_cls_kwargs # noqa from .langhelpers import get_func_kwargs # noqa from .langhelpers import getargspec_init # noqa +from .langhelpers import has_compiled_ext # noqa from .langhelpers import HasMemoized # noqa from .langhelpers import hybridmethod # noqa from .langhelpers import hybridproperty # noqa diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index b18cc13de4..7484a8f1a1 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -363,7 +363,6 @@ else: class OrderedSet(set): def __init__(self, d=None): set.__init__(self) - self._list = [] if d is not None: self._list = unique_list(d) set.update(self, self._list) @@ -521,7 +520,10 @@ class IdentitySet(object): return True def issubset(self, iterable): - other = self.__class__(iterable) + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) if len(self) > len(other): return False @@ -542,7 +544,10 @@ class IdentitySet(object): return len(self) < len(other) and self.issubset(other) def issuperset(self, iterable): - other = self.__class__(iterable) + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) if len(self) < len(other): return False @@ -587,7 +592,10 @@ class IdentitySet(object): def difference(self, iterable): result = self.__class__() members = self._members - other = {id(obj) for obj in iterable} + if isinstance(iterable, self.__class__): + other = set(iterable._members.keys()) + else: + other = {id(obj) for obj in iterable} result._members.update( ((k, v) for k, v in members.items() if k not in other) ) @@ -610,7 +618,10 @@ class IdentitySet(object): def intersection(self, iterable): result = self.__class__() members = self._members - other = {id(obj) for obj in iterable} + if isinstance(iterable, self.__class__): + other = set(iterable._members.keys()) + else: + other = {id(obj) for obj in iterable} result._members.update( (k, v) for k, v in members.items() if k in other ) @@ -633,7 +644,10 @@ class IdentitySet(object): def symmetric_difference(self, iterable): result = self.__class__() members = self._members - other = {id(obj): obj for obj in iterable} + if isinstance(iterable, self.__class__): + other = iterable._members + else: + other = {id(obj): obj for obj in iterable} result._members.update( ((k, v) for k, v in members.items() if k not in other) ) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 457d2875de..eb582b5284 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1879,3 +1879,14 @@ def repr_tuple_names(names): return ", ".join(res) else: return "%s, ..., %s" % (", ".join(res[0:3]), res[-1]) + + +def has_compiled_ext(): + try: + from sqlalchemy import cimmutabledict # noqa F401 + from sqlalchemy import cprocessors # noqa F401 + from sqlalchemy import cresultproxy # noqa F401 + + return True + except ImportError: + return False diff --git a/setup.py b/setup.py index 1d2de5a61e..55a3cee6f9 100644 --- a/setup.py +++ b/setup.py @@ -44,11 +44,6 @@ ext_modules = [ sources=["lib/sqlalchemy/cextension/immutabledict.c"], extra_compile_args=extra_compile_args, ), - Extension( - "sqlalchemy.cutils", - sources=["lib/sqlalchemy/cextension/utils.c"], - extra_compile_args=extra_compile_args, - ), ] diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py index 5a4220c827..ad643a4465 100644 --- a/test/engine/test_processors.py +++ b/test/engine/test_processors.py @@ -188,21 +188,4 @@ class PyDistillArgsTest(_DistillArgsTest): def setup_test_class(cls): from sqlalchemy.engine import util - cls.module = type( - "util", - (object,), - dict( - (k, staticmethod(v)) - for k, v in list(util.py_fallback().items()) - ), - ) - - -class CDistillArgsTest(_DistillArgsTest): - __requires__ = ("cextensions",) - - @classmethod - def setup_test_class(cls): - from sqlalchemy import cutils as util - cls.module = util