]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Minor optimization to the code
authorFederico Caselli <cfederico87@gmail.com>
Thu, 18 Feb 2021 18:12:45 +0000 (19:12 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 18 Feb 2021 19:37:22 +0000 (20:37 +0100)
* remove the c version of distill params since it's actually slower than
the python one
* add a function to langhelpers to check if the cextensions are active
* minor cleanup to the OrderedSet implementation

Change-Id: Iec3d0c3f0f42cdf51f802aaca342ba37b8783b85

.github/workflows/create-wheels.yaml
lib/sqlalchemy/cextension/utils.c [deleted file]
lib/sqlalchemy/engine/util.py
lib/sqlalchemy/testing/profiling.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/langhelpers.py
setup.py
test/engine/test_processors.py

index b11a14619c291c63ce6da2079b30ac20ab599cab..e6fd1e3bd5161d9dc986ce7aa9d438710582d2a0 100644 (file)
@@ -87,7 +87,7 @@ jobs:
         # for python 2.7 visual studio 9 is missing
         if: matrix.os != 'windows-latest' || matrix.python-version != '2.7'
         run: |
-          python -c 'from sqlalchemy import cprocessors, cresultproxy, cutils'
+          python -c 'from sqlalchemy.util import has_compiled_ext; assert has_compiled_ext()'
 
       - name: Test created wheel
         # the mock reconnect test seems to fail on the ci in windows
@@ -222,7 +222,7 @@ jobs:
           then
             pip install greenlet "importlib-metadata;python_version<'3.8'"
             pip install -f dist --no-index sqlalchemy
-            python -c 'from sqlalchemy import cprocessors, cresultproxy, cutils'
+            python -c 'from sqlalchemy.util import has_compiled_ext; assert has_compiled_ext()'
             pip install pytest pytest-xdist ${{ matrix.extra-requires }}
             pytest -n2 -q test --nomemory --notimingintensive
           else
@@ -362,7 +362,7 @@ jobs:
             python --version &&
             pip install greenlet \"importlib-metadata;python_version<'3.8'\" &&
             pip install -f dist --no-index sqlalchemy &&
-            python -c 'from sqlalchemy import cprocessors, cresultproxy, cutils' &&
+            python -c 'from sqlalchemy.util import has_compiled_ext; assert has_compiled_ext()' &&
             pip install pytest pytest-xdist ${{ matrix.extra-requires }} &&
             pytest -n2 -q test --nomemory --notimingintensive"
 
diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c
deleted file mode 100644 (file)
index e06843c..0000000
+++ /dev/null
@@ -1,249 +0,0 @@
-/*
-utils.c
-Copyright (C) 2012-2021 the SQLAlchemy authors and contributors <see AUTHORS file>
-
-This module is part of SQLAlchemy and is released under
-the MIT License: http://www.opensource.org/licenses/mit-license.php
-*/
-
-#include <Python.h>
-
-#define MODULE_NAME "cutils"
-#define MODULE_DOC "Module containing C versions of utility functions."
-
-/*
-    Given arguments from the calling form *multiparams, **params,
-    return a list of bind parameter structures, usually a list of
-    dictionaries.
-
-    In the case of 'raw' execution which accepts positional parameters,
-    it may be a list of tuples or lists.
-
- */
-static PyObject *
-distill_params(PyObject *self, PyObject *args)
-{
-       // TODO: pass the Connection in so that there can be a standard
-       // method for warning on parameter format
-
-       PyObject *connection, *multiparams, *params;
-       PyObject *enclosing_list, *double_enclosing_list;
-       PyObject *zero_element, *zero_element_item;
-    PyObject *tmp;
-       Py_ssize_t multiparam_size, zero_element_length;
-
-       if (!PyArg_UnpackTuple(args, "_distill_params", 3, 3, &connection, &multiparams, &params)) {
-               return NULL;
-       }
-
-       if (multiparams != Py_None) {
-               multiparam_size = PyTuple_Size(multiparams);
-               if (multiparam_size < 0) {
-                       return NULL;
-               }
-       }
-       else {
-               multiparam_size = 0;
-       }
-
-       if (multiparam_size == 0) {
-               if (params != Py_None && PyMapping_Size(params) != 0) {
-
-                   tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", "");
-               if (tmp == NULL) {
-                   return NULL;
-               }
-
-                       enclosing_list = PyList_New(1);
-                       if (enclosing_list == NULL) {
-                               return NULL;
-                       }
-                       Py_INCREF(params);
-                       if (PyList_SetItem(enclosing_list, 0, params) == -1) {
-                               Py_DECREF(params);
-                               Py_DECREF(enclosing_list);
-                               return NULL;
-                       }
-               }
-               else {
-                       enclosing_list = PyList_New(0);
-                       if (enclosing_list == NULL) {
-                               return NULL;
-                       }
-               }
-               return enclosing_list;
-       }
-       else if (multiparam_size == 1) {
-               zero_element = PyTuple_GetItem(multiparams, 0);
-               if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) {
-                       zero_element_length = PySequence_Length(zero_element);
-
-                       if (zero_element_length != 0) {
-                               zero_element_item = PySequence_GetItem(zero_element, 0);
-                               if (zero_element_item == NULL) {
-                                       return NULL;
-                               }
-                       }
-                       else {
-                               zero_element_item = NULL;
-                       }
-
-                       if (zero_element_length == 0 ||
-                                       (
-                                               PyObject_HasAttrString(zero_element_item, "__iter__") &&
-                                               !PyObject_HasAttrString(zero_element_item, "strip")
-                                       )
-                               ) {
-                               /*
-                                * execute(stmt, [{}, {}, {}, ...])
-                        * execute(stmt, [(), (), (), ...])
-                                */
-                               Py_XDECREF(zero_element_item);
-                               Py_INCREF(zero_element);
-                               return zero_element;
-                       }
-                       else {
-                               /*
-                                * execute(stmt, ("value", "value"))
-                                */
-                               Py_XDECREF(zero_element_item);
-
-                               enclosing_list = PyList_New(1);
-                               if (enclosing_list == NULL) {
-                                       return NULL;
-                               }
-                               Py_INCREF(zero_element);
-                               if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
-                                       Py_DECREF(zero_element);
-                                       Py_DECREF(enclosing_list);
-                                       return NULL;
-                               }
-                               return enclosing_list;
-                       }
-               }
-               else if (PyObject_HasAttrString(zero_element, "keys")) {
-                       /*
-                        * execute(stmt, {"key":"value"})
-                        */
-                       enclosing_list = PyList_New(1);
-                       if (enclosing_list ==  NULL) {
-                               return NULL;
-                       }
-                       Py_INCREF(zero_element);
-                       if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
-                               Py_DECREF(zero_element);
-                               Py_DECREF(enclosing_list);
-                               return NULL;
-                       }
-                       return enclosing_list;
-               } else {
-                   tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", "");
-               if (tmp == NULL) {
-                   return NULL;
-               }
-
-                       enclosing_list = PyList_New(1);
-                       if (enclosing_list ==  NULL) {
-                               return NULL;
-                       }
-                       double_enclosing_list = PyList_New(1);
-                       if (double_enclosing_list == NULL) {
-                               Py_DECREF(enclosing_list);
-                               return NULL;
-                       }
-                       Py_INCREF(zero_element);
-                       if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) {
-                               Py_DECREF(zero_element);
-                               Py_DECREF(enclosing_list);
-                               Py_DECREF(double_enclosing_list);
-                               return NULL;
-                       }
-                       if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) {
-                               Py_DECREF(zero_element);
-                               Py_DECREF(enclosing_list);
-                               Py_DECREF(double_enclosing_list);
-                               return NULL;
-                       }
-                       return double_enclosing_list;
-               }
-       }
-       else {
-
-           tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", "");
-        if (tmp == NULL) {
-            return NULL;
-        }
-
-               zero_element = PyTuple_GetItem(multiparams, 0);
-               if (PyObject_HasAttrString(zero_element, "__iter__") &&
-                               !PyObject_HasAttrString(zero_element, "strip")
-                       ) {
-                       Py_INCREF(multiparams);
-                       return multiparams;
-               }
-               else {
-                       enclosing_list = PyList_New(1);
-                       if (enclosing_list ==  NULL) {
-                               return NULL;
-                       }
-                       Py_INCREF(multiparams);
-                       if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) {
-                               Py_DECREF(multiparams);
-                               Py_DECREF(enclosing_list);
-                               return NULL;
-                       }
-                       return enclosing_list;
-               }
-       }
-}
-
-static PyMethodDef module_methods[] = {
-    {"_distill_params", distill_params, METH_VARARGS,
-     "Distill an execute() parameter structure."},
-    {NULL, NULL, 0, NULL}        /* Sentinel */
-};
-
-#ifndef PyMODINIT_FUNC  /* declarations for DLL import/export */
-#define PyMODINIT_FUNC void
-#endif
-
-#if PY_MAJOR_VERSION >= 3
-
-#define INITERROR return NULL
-
-static struct PyModuleDef module_def = {
-    PyModuleDef_HEAD_INIT,
-    MODULE_NAME,
-    MODULE_DOC,
-    -1,
-    module_methods
- };
-
-PyMODINIT_FUNC
-PyInit_cutils(void)
-
-#else
-
-#define INITERROR return
-
-PyMODINIT_FUNC
-initcutils(void)
-
-#endif
-
-{
-    PyObject *m;
-
-#if PY_MAJOR_VERSION >= 3
-    m = PyModule_Create(&module_def);
-#else
-    m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
-#endif
-    if (m == NULL)
-        INITERROR;
-
-#if PY_MAJOR_VERSION >= 3
-    return m;
-#endif
-}
-
index 4e302f464f98ea4bd109b02db4a8bd1fcc52edbf..ede2631985ae17ab2429a4904f425226126b4310 100644 (file)
@@ -34,63 +34,55 @@ _no_tuple = ()
 _no_kw = util.immutabledict()
 
 
-def py_fallback():
-    # TODO: pass the Connection in so that there can be a standard
-    # method for warning on parameter format
-    def _distill_params(connection, multiparams, params):  # noqa
-        r"""Given arguments from the calling form \*multiparams, \**params,
-        return a list of bind parameter structures, usually a list of
-        dictionaries.
-
-        In the case of 'raw' execution which accepts positional parameters,
-        it may be a list of tuples or lists.
-
-        """
-
-        # C version will fail if this assertion is not true.
-        # assert isinstance(multiparams, tuple)
-
-        if not multiparams:
-            if params:
-                connection._warn_for_legacy_exec_format()
-                return [params]
+def _distill_params(connection, multiparams, params):
+    r"""Given arguments from the calling form \*multiparams, \**params,
+    return a list of bind parameter structures, usually a list of
+    dictionaries.
+
+    In the case of 'raw' execution which accepts positional parameters,
+    it may be a list of tuples or lists.
+
+    """
+
+    if not multiparams:
+        if params:
+            connection._warn_for_legacy_exec_format()
+            return [params]
+        else:
+            return []
+    elif len(multiparams) == 1:
+        zero = multiparams[0]
+        if isinstance(zero, (list, tuple)):
+            if (
+                not zero
+                or hasattr(zero[0], "__iter__")
+                and not hasattr(zero[0], "strip")
+            ):
+                # execute(stmt, [{}, {}, {}, ...])
+                # execute(stmt, [(), (), (), ...])
+                return zero
             else:
-                return []
-        elif len(multiparams) == 1:
-            zero = multiparams[0]
-            if isinstance(zero, (list, tuple)):
-                if (
-                    not zero
-                    or hasattr(zero[0], "__iter__")
-                    and not hasattr(zero[0], "strip")
-                ):
-                    # execute(stmt, [{}, {}, {}, ...])
-                    # execute(stmt, [(), (), (), ...])
-                    return zero
-                else:
-                    # this is used by exec_driver_sql only, so a deprecation
-                    # warning would already be coming from passing a plain
-                    # textual statement with positional parameters to
-                    # execute().
-                    # execute(stmt, ("value", "value"))
-                    return [zero]
-            elif hasattr(zero, "keys"):
-                # execute(stmt, {"key":"value"})
+                # this is used by exec_driver_sql only, so a deprecation
+                # warning would already be coming from passing a plain
+                # textual statement with positional parameters to
+                # execute().
+                # execute(stmt, ("value", "value"))
                 return [zero]
-            else:
-                connection._warn_for_legacy_exec_format()
-                # execute(stmt, "value")
-                return [[zero]]
+        elif hasattr(zero, "keys"):
+            # execute(stmt, {"key":"value"})
+            return [zero]
         else:
             connection._warn_for_legacy_exec_format()
-            if hasattr(multiparams[0], "__iter__") and not hasattr(
-                multiparams[0], "strip"
-            ):
-                return multiparams
-            else:
-                return [multiparams]
-
-    return locals()
+            # execute(stmt, "value")
+            return [[zero]]
+    else:
+        connection._warn_for_legacy_exec_format()
+        if hasattr(multiparams[0], "__iter__") and not hasattr(
+            multiparams[0], "strip"
+        ):
+            return multiparams
+        else:
+            return [multiparams]
 
 
 def _distill_cursor_params(connection, multiparams, params):
@@ -161,9 +153,3 @@ def _distill_params_20(params):
         return (params,), _no_kw
     else:
         raise exc.ArgumentError("mapping or sequence expected for parameters")
-
-
-try:
-    from sqlalchemy.cutils import _distill_params  # noqa
-except ImportError:
-    globals().update(py_fallback())
index 16c6d458c940c462c10d6ac3145a146947102261..5e4f19273689abbcba93418a3895e926a1b09014 100644 (file)
@@ -22,6 +22,7 @@ import sys
 
 from . import config
 from .util import gc_collect
+from ..util import has_compiled_ext
 
 
 try:
@@ -109,7 +110,7 @@ class ProfileStatsFile(object):
             if config.db.dialect.convert_unicode
             else "dbapiunicode"
         )
-        _has_cext = config.requirements._has_cextensions()
+        _has_cext = has_compiled_ext()
         platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
         return "_".join(platform_tokens)
 
index d8da9c8184657ae34f0246f508e45f95390c057c..f16ba326cc8abde621f11ecbdd77964f7df46df1 100644 (file)
@@ -1258,7 +1258,7 @@ class SuiteRequirements(Requirements):
     @property
     def cextensions(self):
         return exclusions.skip_if(
-            lambda: not self._has_cextensions(), "C extensions not installed"
+            lambda: not util.has_compiled_ext(), "C extensions not installed"
         )
 
     def _has_sqlite(self):
@@ -1270,14 +1270,6 @@ class SuiteRequirements(Requirements):
         except ImportError:
             return False
 
-    def _has_cextensions(self):
-        try:
-            from sqlalchemy import cresultproxy, cprocessors  # noqa
-
-            return True
-        except ImportError:
-            return False
-
     @property
     def async_dialect(self):
         """dialect makes use of await_() to invoke operations on the DBAPI."""
index 2d86b8b633024459f4a52bba7b4b4ca7fc16610c..4b61658b21df9600fcdff157d78acf4629254d2a 100644 (file)
@@ -140,6 +140,7 @@ from .langhelpers import get_callable_argspec  # noqa
 from .langhelpers import get_cls_kwargs  # noqa
 from .langhelpers import get_func_kwargs  # noqa
 from .langhelpers import getargspec_init  # noqa
+from .langhelpers import has_compiled_ext  # noqa
 from .langhelpers import HasMemoized  # noqa
 from .langhelpers import hybridmethod  # noqa
 from .langhelpers import hybridproperty  # noqa
index b18cc13de46a9763f37da07110d2be1857804022..7484a8f1a1eef96c03c813150dcd5f84bf89b6ff 100644 (file)
@@ -363,7 +363,6 @@ else:
 class OrderedSet(set):
     def __init__(self, d=None):
         set.__init__(self)
-        self._list = []
         if d is not None:
             self._list = unique_list(d)
             set.update(self, self._list)
@@ -521,7 +520,10 @@ class IdentitySet(object):
             return True
 
     def issubset(self, iterable):
-        other = self.__class__(iterable)
+        if isinstance(iterable, self.__class__):
+            other = iterable
+        else:
+            other = self.__class__(iterable)
 
         if len(self) > len(other):
             return False
@@ -542,7 +544,10 @@ class IdentitySet(object):
         return len(self) < len(other) and self.issubset(other)
 
     def issuperset(self, iterable):
-        other = self.__class__(iterable)
+        if isinstance(iterable, self.__class__):
+            other = iterable
+        else:
+            other = self.__class__(iterable)
 
         if len(self) < len(other):
             return False
@@ -587,7 +592,10 @@ class IdentitySet(object):
     def difference(self, iterable):
         result = self.__class__()
         members = self._members
-        other = {id(obj) for obj in iterable}
+        if isinstance(iterable, self.__class__):
+            other = set(iterable._members.keys())
+        else:
+            other = {id(obj) for obj in iterable}
         result._members.update(
             ((k, v) for k, v in members.items() if k not in other)
         )
@@ -610,7 +618,10 @@ class IdentitySet(object):
     def intersection(self, iterable):
         result = self.__class__()
         members = self._members
-        other = {id(obj) for obj in iterable}
+        if isinstance(iterable, self.__class__):
+            other = set(iterable._members.keys())
+        else:
+            other = {id(obj) for obj in iterable}
         result._members.update(
             (k, v) for k, v in members.items() if k in other
         )
@@ -633,7 +644,10 @@ class IdentitySet(object):
     def symmetric_difference(self, iterable):
         result = self.__class__()
         members = self._members
-        other = {id(obj): obj for obj in iterable}
+        if isinstance(iterable, self.__class__):
+            other = iterable._members
+        else:
+            other = {id(obj): obj for obj in iterable}
         result._members.update(
             ((k, v) for k, v in members.items() if k not in other)
         )
index 457d2875de84c45c99be49063552337b1b7804d0..eb582b5284c66f439b53dd9dd650cd286ac5f251 100644 (file)
@@ -1879,3 +1879,14 @@ def repr_tuple_names(names):
         return ", ".join(res)
     else:
         return "%s, ..., %s" % (", ".join(res[0:3]), res[-1])
+
+
+def has_compiled_ext():
+    try:
+        from sqlalchemy import cimmutabledict  # noqa F401
+        from sqlalchemy import cprocessors  # noqa F401
+        from sqlalchemy import cresultproxy  # noqa F401
+
+        return True
+    except ImportError:
+        return False
index 1d2de5a61e82aeb2813b8808abea1af878693351..55a3cee6f984b065d7c4dfc44f10d5b9fa6f0571 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -44,11 +44,6 @@ ext_modules = [
         sources=["lib/sqlalchemy/cextension/immutabledict.c"],
         extra_compile_args=extra_compile_args,
     ),
-    Extension(
-        "sqlalchemy.cutils",
-        sources=["lib/sqlalchemy/cextension/utils.c"],
-        extra_compile_args=extra_compile_args,
-    ),
 ]
 
 
index 5a4220c827d12cb0b54eb8d3d5cf7583e5835389..ad643a44654bc1820006539267dddeb118834970 100644 (file)
@@ -188,21 +188,4 @@ class PyDistillArgsTest(_DistillArgsTest):
     def setup_test_class(cls):
         from sqlalchemy.engine import util
 
-        cls.module = type(
-            "util",
-            (object,),
-            dict(
-                (k, staticmethod(v))
-                for k, v in list(util.py_fallback().items())
-            ),
-        )
-
-
-class CDistillArgsTest(_DistillArgsTest):
-    __requires__ = ("cextensions",)
-
-    @classmethod
-    def setup_test_class(cls):
-        from sqlalchemy import cutils as util
-
         cls.module = util