]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add immutabledict C code
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 May 2020 04:06:06 +0000 (00:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 May 2020 04:05:13 +0000 (00:05 -0400)
Start trying to convert fundamental objects to
C as we now rely on a fairly small core of things,
and 1.4 is having problems with complexity added being
slower than the performance gains we are trying to build in.

immutabledict here does seem to bench as twice as fast as the
Python one, see below.  However, it does not appear to be
used prominently enough to make any dent in the performance
tests.

at the very least it may provide us some more lift-and-copy
code for more C extensions.

import timeit

from sqlalchemy.util._collections import not_immutabledict, immutabledict

def run(dict_cls):
    for i in range(1000000):
        d1 = dict_cls({"x": 5, "y": 4})

        d2 = d1.union({"x": 17, "new key": "some other value"}, None)

        assert list(d2) == ["x", "y", "new key"]

print(
    timeit.timeit(
        "run(d)", "from __main__ import run, not_immutabledict as d", number=1
    )
)
print(
    timeit.timeit(
        "run(d)", "from __main__ import run, immutabledict as d", number=1
    )
)

output:

python: 1.8799766399897635
C code: 0.8880784640205093

Change-Id: I29e7104dc21dcc7cdf895bf274003af2e219bf6d

15 files changed:
lib/sqlalchemy/cextension/immutabledict.c [new file with mode: 0644]
lib/sqlalchemy/cextension/resultproxy.c
lib/sqlalchemy/cextension/utils.c
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
setup.py
test/base/test_utils.py
test/engine/test_execute.py
test/orm/test_options.py
test/orm/test_update_delete.py
test/sql/test_metadata.py

diff --git a/lib/sqlalchemy/cextension/immutabledict.c b/lib/sqlalchemy/cextension/immutabledict.c
new file mode 100644 (file)
index 0000000..2a19cf3
--- /dev/null
@@ -0,0 +1,475 @@
+/*
+immuatbledict.c
+Copyright (C) 2020 the SQLAlchemy authors and contributors <see AUTHORS file>
+
+This module is part of SQLAlchemy and is released under
+the MIT License: http://www.opensource.org/licenses/mit-license.php
+*/
+
+#include <Python.h>
+
+#define MODULE_NAME "cimmutabledict"
+#define MODULE_DOC "immutable dictionary implementation"
+
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *dict;
+} ImmutableDict;
+
+static PyTypeObject ImmutableDictType;
+
+
+
+static PyObject *
+
+ImmutableDict_new(PyTypeObject *type, PyObject *args, PyObject *kw)
+
+{
+    ImmutableDict *new_obj;
+    PyObject *arg_dict = NULL;
+    PyObject *our_dict;
+
+    if (!PyArg_UnpackTuple(args, "ImmutableDict", 0, 1, &arg_dict)) {
+        return NULL;
+    }
+
+    if (arg_dict != NULL && PyDict_CheckExact(arg_dict)) {
+        // going on the unproven theory that doing PyDict_New + PyDict_Update
+        // is faster than just calling CallObject, as we do below to
+        // accommodate for other dictionary argument forms
+        our_dict = PyDict_New();
+        if (our_dict == NULL) {
+            return NULL;
+        }
+
+        if (PyDict_Update(our_dict, arg_dict) == -1) {
+            Py_DECREF(our_dict);
+            return NULL;
+        }
+    }
+    else {
+        // for other calling styles, let PyDict figure it out
+        our_dict = PyObject_Call((PyObject *) &PyDict_Type, args, kw);
+    }
+
+    new_obj = PyObject_GC_New(ImmutableDict, &ImmutableDictType);
+    if (new_obj == NULL) {
+        Py_DECREF(our_dict);
+        return NULL;
+    }
+    new_obj->dict = our_dict;
+    PyObject_GC_Track(new_obj);
+
+    return (PyObject *)new_obj;
+
+}
+
+
+Py_ssize_t
+ImmutableDict_length(ImmutableDict *self)
+{
+    return PyDict_Size(self->dict);
+}
+
+static PyObject *
+ImmutableDict_subscript(ImmutableDict *self, PyObject *key)
+{
+    PyObject *value;
+#if PY_MAJOR_VERSION >= 3
+    PyObject *err_bytes;
+#endif
+
+    value = PyDict_GetItem((PyObject *)self->dict, key);
+
+    if (value == NULL) {
+#if PY_MAJOR_VERSION >= 3
+        err_bytes = PyUnicode_AsUTF8String(key);
+        if (err_bytes == NULL)
+            return NULL;
+        PyErr_Format(PyExc_KeyError, "%s", PyBytes_AS_STRING(err_bytes));
+#else
+        PyErr_Format(PyExc_KeyError, "%s", PyString_AsString(key));
+#endif
+        return NULL;
+    }
+
+    Py_INCREF(value);
+
+    return value;
+}
+
+
+static void
+ImmutableDict_dealloc(ImmutableDict *self)
+{
+    PyObject_GC_UnTrack(self);
+    Py_XDECREF(self->dict);
+    PyObject_GC_Del(self);
+}
+
+
+static PyObject *
+ImmutableDict_reduce(ImmutableDict *self)
+{
+    return Py_BuildValue("O(O)", Py_TYPE(self), self->dict);
+}
+
+
+static PyObject *
+ImmutableDict_repr(ImmutableDict *self)
+{
+    return PyUnicode_FromFormat("immutabledict(%R)", self->dict);
+}
+
+
+static PyObject *
+ImmutableDict_union(PyObject *self, PyObject *args, PyObject *kw)
+{
+    PyObject *arg_dict, *new_dict;
+
+    ImmutableDict *new_obj;
+
+    if (!PyArg_UnpackTuple(args, "ImmutableDict", 0, 1, &arg_dict)) {
+        return NULL;
+    }
+
+    if (!PyDict_CheckExact(arg_dict)) {
+        // if we didnt get a dict, and got lists of tuples or
+        // keyword args, make a dict
+        arg_dict = PyObject_Call((PyObject *) &PyDict_Type, args, kw);
+        if (arg_dict == NULL) {
+            return NULL;
+        }
+    }
+    else {
+        // otherwise we will use the dict as is
+        Py_INCREF(arg_dict);
+    }
+
+    if (PyDict_Size(arg_dict) == 0) {
+        Py_DECREF(arg_dict);
+        Py_INCREF(self);
+        return self;
+    }
+
+    new_dict = PyDict_New();
+    if (new_dict == NULL) {
+        Py_DECREF(arg_dict);
+        return NULL;
+    }
+
+    if (PyDict_Update(new_dict, ((ImmutableDict *)self)->dict) == -1) {
+        Py_DECREF(arg_dict);
+        Py_DECREF(new_dict);
+        return NULL;
+    }
+
+    if (PyDict_Update(new_dict, arg_dict) == -1) {
+        Py_DECREF(arg_dict);
+        Py_DECREF(new_dict);
+        return NULL;
+    }
+
+    Py_DECREF(arg_dict);
+
+    new_obj = PyObject_GC_New(ImmutableDict, Py_TYPE(self));
+    if (new_obj == NULL) {
+        Py_DECREF(new_dict);
+        return NULL;
+    }
+
+    new_obj->dict = new_dict;
+
+    PyObject_GC_Track(new_obj);
+
+    return (PyObject *)new_obj;
+}
+
+
+static PyObject *
+ImmutableDict_merge_with(PyObject *self, PyObject *args)
+{
+    PyObject *element, *arg, *new_dict = NULL;
+
+    ImmutableDict *new_obj;
+
+    Py_ssize_t num_args = PyTuple_Size(args);
+    Py_ssize_t i;
+
+    for (i=0; i<num_args; i++) {
+        element = PyTuple_GetItem(args, i);
+
+        if (element == NULL) {
+            Py_XDECREF(new_dict);
+            return NULL;
+        }
+        else if (element == Py_None) {
+            // none was passed, skip it
+            continue;
+        }
+
+        if (!PyDict_CheckExact(element)) {
+            // not a dict, try to make a dict
+
+            arg = PyTuple_Pack(1, element);
+
+            element = PyObject_CallObject((PyObject *) &PyDict_Type, arg);
+
+            Py_DECREF(arg);
+
+            if (element == NULL) {
+                Py_XDECREF(new_dict);
+                return NULL;
+            }
+        }
+        else {
+            Py_INCREF(element);
+            if (PyDict_Size(element) == 0) {
+                continue;
+            }
+        }
+
+        // initialize a new dictionary only if we receive data that
+        // is not empty.  otherwise we return self at the end.
+        if (new_dict == NULL) {
+
+            new_dict = PyDict_New();
+            if (new_dict == NULL) {
+                Py_DECREF(element);
+                return NULL;
+            }
+
+            if (PyDict_Update(new_dict, ((ImmutableDict *)self)->dict) == -1) {
+                Py_DECREF(element);
+                Py_DECREF(new_dict);
+                return NULL;
+            }
+        }
+
+        if (PyDict_Update(new_dict, element) == -1) {
+            Py_DECREF(element);
+            Py_DECREF(new_dict);
+            return NULL;
+        }
+
+        Py_DECREF(element);
+    }
+
+
+    if (new_dict != NULL) {
+        new_obj = PyObject_GC_New(ImmutableDict, Py_TYPE(self));
+        if (new_obj == NULL) {
+            Py_DECREF(new_dict);
+            return NULL;
+        }
+
+        new_obj->dict = new_dict;
+        PyObject_GC_Track(new_obj);
+        return (PyObject *)new_obj;
+    }
+    else {
+        Py_INCREF(self);
+        return self;
+    }
+
+}
+
+
+static PyObject *
+ImmutableDict_get(ImmutableDict *self, PyObject *args)
+{
+    PyObject *key;
+    PyObject *default_value = Py_None;
+
+    if (!PyArg_UnpackTuple(args, "key", 1, 2, &key, &default_value)) {
+        return NULL;
+    }
+
+
+    return PyObject_CallMethod(self->dict, "get", "OO", key, default_value);
+}
+
+static PyObject *
+ImmutableDict_keys(ImmutableDict *self)
+{
+    return PyObject_CallMethod(self->dict, "keys", "");
+}
+
+static int
+ImmutableDict_traverse(ImmutableDict *self, visitproc visit, void *arg)
+{
+    Py_VISIT(self->dict);
+    return 0;
+}
+
+static PyObject *
+ImmutableDict_richcompare(ImmutableDict *self, PyObject *other, int op)
+{
+    return PyObject_RichCompare(self->dict, other, op);
+}
+
+static PyObject *
+ImmutableDict_iter(ImmutableDict *self)
+{
+    return PyObject_CallMethod(self->dict, "__iter__", "");
+}
+
+static PyObject *
+ImmutableDict_items(ImmutableDict *self)
+{
+    return PyObject_CallMethod(self->dict, "items", "");
+}
+
+static PyObject *
+ImmutableDict_values(ImmutableDict *self)
+{
+    return PyObject_CallMethod(self->dict, "values", "");
+
+}
+
+static PyObject *
+ImmutableDict_contains(ImmutableDict *self, PyObject *key)
+{
+    int ret;
+
+    ret = PyDict_Contains(self->dict, key);
+
+    if (ret == 1) Py_RETURN_TRUE;
+    else if (ret == 0) Py_RETURN_FALSE;
+    else return NULL;
+}
+
+static PyMethodDef ImmutableDict_methods[] = {
+    {"union", (PyCFunction) ImmutableDict_union, METH_VARARGS | METH_KEYWORDS,
+     "provide a union of this dictionary with the given dictionary-like arguments"},
+    {"merge_with", (PyCFunction) ImmutableDict_merge_with, METH_VARARGS,
+     "provide a union of this dictionary with those given"},
+    {"keys", (PyCFunction) ImmutableDict_keys, METH_NOARGS,
+     "return dictionary keys"},
+
+     {"__contains__",(PyCFunction)ImmutableDict_contains, METH_O,
+     "test a member for containment"},
+
+    {"items", (PyCFunction) ImmutableDict_items, METH_NOARGS,
+     "return dictionary items"},
+    {"values", (PyCFunction) ImmutableDict_values, METH_NOARGS,
+     "return dictionary values"},
+    {"get", (PyCFunction) ImmutableDict_get, METH_VARARGS,
+     "get a value"},
+    {"__reduce__",  (PyCFunction)ImmutableDict_reduce, METH_NOARGS,
+     "Pickle support method."},
+    {NULL},
+};
+
+
+static PyMappingMethods ImmutableDict_as_mapping = {
+    (lenfunc)ImmutableDict_length,       /* mp_length */
+    (binaryfunc)ImmutableDict_subscript, /* mp_subscript */
+    0                                   /* mp_ass_subscript */
+};
+
+
+
+
+static PyTypeObject ImmutableDictType = {
+    PyVarObject_HEAD_INIT(NULL, 0)
+    "sqlalchemy.cimmutabledict.immutabledict",          /* tp_name */
+    sizeof(ImmutableDict),               /* tp_basicsize */
+    0,                                  /* tp_itemsize */
+    (destructor)ImmutableDict_dealloc,  /* tp_dealloc */
+    0,                                  /* tp_print */
+    0,                                  /* tp_getattr */
+    0,                                  /* tp_setattr */
+    0,                                  /* tp_compare */
+    (reprfunc)ImmutableDict_repr,               /* tp_repr */
+    0,                                  /* tp_as_number */
+    0,                                   /* tp_as_sequence */
+    &ImmutableDict_as_mapping,            /* tp_as_mapping */
+    0,                                 /* tp_hash */
+    0,                                  /* tp_call */
+    0,                                  /* tp_str */
+    0,                                   /* tp_getattro */
+    0,                                  /* tp_setattro */
+    0,                                  /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC , /* tp_flags */
+    "immutable dictionary",                         /* tp_doc */
+    (traverseproc)ImmutableDict_traverse,          /* tp_traverse */
+    0,                                  /* tp_clear */
+    (richcmpfunc)ImmutableDict_richcompare, /* tp_richcompare */
+    0,                                  /* tp_weaklistoffset */
+    (getiterfunc)ImmutableDict_iter,     /* tp_iter */
+    0,                                  /* tp_iternext */
+    ImmutableDict_methods,               /* tp_methods */
+    0,                                  /* tp_members */
+    0,                                     /* tp_getset */
+    0,                                  /* tp_base */
+    0,                                  /* tp_dict */
+    0,                                  /* tp_descr_get */
+    0,                                  /* tp_descr_set */
+    0,                                  /* tp_dictoffset */
+    0,                                 /* tp_init */
+    0,                                  /* tp_alloc */
+    ImmutableDict_new,                   /* tp_new */
+    0,                                    /* tp_free */
+};
+
+
+
+
+
+static PyMethodDef module_methods[] = {
+    {NULL, NULL, 0, NULL}        /* Sentinel */
+};
+
+#ifndef PyMODINIT_FUNC  /* declarations for DLL import/export */
+#define PyMODINIT_FUNC void
+#endif
+
+
+#if PY_MAJOR_VERSION >= 3
+
+static struct PyModuleDef module_def = {
+    PyModuleDef_HEAD_INIT,
+    MODULE_NAME,
+    MODULE_DOC,
+    -1,
+    module_methods
+};
+
+#define INITERROR return NULL
+
+PyMODINIT_FUNC
+PyInit_cimmutabledict(void)
+
+#else
+
+#define INITERROR return
+
+PyMODINIT_FUNC
+initcimmutabledict(void)
+
+#endif
+
+{
+    PyObject *m;
+
+    if (PyType_Ready(&ImmutableDictType) < 0)
+        INITERROR;
+
+
+#if PY_MAJOR_VERSION >= 3
+    m = PyModule_Create(&module_def);
+#else
+    m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC);
+#endif
+    if (m == NULL)
+        INITERROR;
+
+    Py_INCREF(&ImmutableDictType);
+    PyModule_AddObject(m, "immutabledict", (PyObject *)&ImmutableDictType);
+
+#if PY_MAJOR_VERSION >= 3
+    return m;
+#endif
+}
index 244379116d24e55589e5cf66f0e0a102a8f977bc..ddf3e190046b999805b3f0a5dd74d3d940388ca0 100644 (file)
@@ -241,7 +241,6 @@ BaseRow_filter_on_values(BaseRow *self, PyObject *filters)
     row_class = PyObject_GetAttrString(sqlalchemy_engine_row, "Row");
 
     key_style = PyLong_FromLong(self->key_style);
-    Py_INCREF(key_style);
 
     new_obj = PyObject_CallFunction(
         row_class, "OOOOO", self->parent, filters, self->keymap,
index fb7fbe4e66bf0a3381d4bcb504b15052fa63089f..ab8b39335c215284f39aeeb389d3ba758bd5ce87 100644 (file)
@@ -46,7 +46,7 @@ distill_params(PyObject *self, PyObject *args)
        }
 
        if (multiparam_size == 0) {
-               if (params != Py_None && PyDict_Size(params) != 0) {
+               if (params != Py_None && PyMapping_Size(params) != 0) {
                        // TODO: this is keyword parameters, emit parameter format
                        // deprecation warning
                        enclosing_list = PyList_New(1);
index 5618a67f9de1953bb9bf6f63caed66a2d060f714..228baa84f67afd6fcb5a0716ce8a1e4988dfa1e3 100644 (file)
@@ -2361,7 +2361,7 @@ class MSDialect(default.DefaultDialect):
     }
 
     engine_config_types = default.DefaultDialect.engine_config_types.union(
-        [("legacy_schema_aliasing", util.asbool)]
+        {"legacy_schema_aliasing": util.asbool}
     )
 
     ischema_names = ischema_names
index 1eaf63ff3449dd69750d845c3bac16d63164c9ec..9585dd46753c8895dc68d4603bd018ee2f9a71b6 100644 (file)
@@ -658,7 +658,7 @@ class PGDialect_psycopg2(PGDialect):
     _has_native_jsonb = False
 
     engine_config_types = PGDialect.engine_config_types.union(
-        [("use_native_unicode", util.asbool)]
+        {"use_native_unicode": util.asbool}
     )
 
     colspecs = util.update_copy(
index 70b8a71e3e5a82062554e92e3403c45cc885c107..7f910afed50e83219e15c6c429dbba2656009b98 100644 (file)
@@ -1551,7 +1551,7 @@ class Query(Generative):
     def _options(self, conditional, *args):
         # most MapperOptions write to the '_attributes' dictionary,
         # so copy that as well
-        self._attributes = self._attributes.copy()
+        self._attributes = dict(self._attributes)
         if "_unbound_load_dedupes" not in self._attributes:
             self._attributes["_unbound_load_dedupes"] = set()
         opts = tuple(util.flatten_iterator(args))
@@ -1720,7 +1720,7 @@ class Query(Generative):
                 "params() takes zero or one positional argument, "
                 "which is a dictionary."
             )
-        self._params = self._params.copy()
+        self._params = dict(self._params)
         self._params.update(kwargs)
 
     @_generative
@@ -2277,7 +2277,7 @@ class Query(Generative):
         # dict, so that no existing dict in the path is mutated
         while "prev" in jp:
             f, prev = jp["prev"]
-            prev = prev.copy()
+            prev = dict(prev)
             prev[f] = jp.copy()
             jp["prev"] = (f, prev)
             jp = prev
@@ -4831,7 +4831,7 @@ class QueryContext(object):
         self.propagate_options = set(
             o for o in query._with_options if o.propagate_to_loaders
         )
-        self.attributes = query._attributes.copy()
+        self.attributes = dict(query._attributes)
         if self.refresh_state is not None:
             self.identity_token = query._refresh_identity_token
         else:
index fd1e3fa38f000713d3de9d28c7c5e61127b78c5c..689eda11d408361c44721983f031aef4e8a8f764 100644 (file)
@@ -3972,7 +3972,7 @@ class MetaData(SchemaItem):
                 examples.
 
         """
-        self.tables = util.immutabledict()
+        self.tables = util.FacadeDict()
         self.schema = quoted_name(schema, quote_schema)
         self.naming_convention = (
             naming_convention
@@ -4015,7 +4015,7 @@ class MetaData(SchemaItem):
 
     def _add_table(self, name, schema, table):
         key = _get_table_key(name, schema)
-        dict.__setitem__(self.tables, key, table)
+        self.tables._insert_item(key, table)
         if schema:
             self._schemas.add(schema)
 
index 6a0b065eea25b5375376cd996ca677919446f93b..55a6cdcf90b324365dee247f4369ab75400ef9d4 100644 (file)
@@ -16,6 +16,7 @@ from ._collections import collections_abc  # noqa
 from ._collections import column_dict  # noqa
 from ._collections import column_set  # noqa
 from ._collections import EMPTY_SET  # noqa
+from ._collections import FacadeDict  # noqa
 from ._collections import flatten_iterator  # noqa
 from ._collections import has_dupes  # noqa
 from ._collections import has_intersection  # noqa
index 0990acb8374e5648b1529444ea6c8803fe2051dd..065935c48265a99c4eb30c93d649121cf10f417e 100644 (file)
@@ -31,7 +31,67 @@ class ImmutableContainer(object):
     __delitem__ = __setitem__ = __setattr__ = _immutable
 
 
-class immutabledict(ImmutableContainer, dict):
+def _immutabledict_py_fallback():
+    class immutabledict(ImmutableContainer, dict):
+
+        clear = (
+            pop
+        ) = popitem = setdefault = update = ImmutableContainer._immutable
+
+        def __new__(cls, *args):
+            new = dict.__new__(cls)
+            dict.__init__(new, *args)
+            return new
+
+        def __init__(self, *args):
+            pass
+
+        def __reduce__(self):
+            return _immutabledict_reconstructor, (dict(self),)
+
+        def union(self, d):
+            if not d:
+                return self
+
+            new = dict.__new__(self.__class__)
+            dict.__init__(new, self)
+            dict.update(new, d)
+            return new
+
+        def merge_with(self, *dicts):
+            new = None
+            for d in dicts:
+                if d:
+                    if new is None:
+                        new = dict.__new__(self.__class__)
+                        dict.__init__(new, self)
+                    dict.update(new, d)
+            if new is None:
+                return self
+
+            return new
+
+        def __repr__(self):
+            return "immutabledict(%s)" % dict.__repr__(self)
+
+    return immutabledict
+
+
+try:
+    from sqlalchemy.cimmutabledict import immutabledict
+
+    collections_abc.Mapping.register(immutabledict)
+
+except ImportError:
+    immutabledict = _immutabledict_py_fallback()
+
+    def _immutabledict_reconstructor(*arg):
+        """do the pickle dance"""
+        return immutabledict(*arg)
+
+
+class FacadeDict(ImmutableContainer, dict):
+    """A dictionary that is not publicly mutable."""
 
     clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
 
@@ -44,24 +104,17 @@ class immutabledict(ImmutableContainer, dict):
         pass
 
     def __reduce__(self):
-        return immutabledict, (dict(self),)
+        return FacadeDict, (dict(self),)
 
-    def union(self, d):
-        new = dict.__new__(self.__class__)
-        dict.__init__(new, self)
-        dict.update(new, d)
-        return new
+    def _insert_item(self, key, value):
+        """insert an item into the dictionary directly.
 
-    def merge_with(self, *dicts):
-        new = dict.__new__(self.__class__)
-        dict.__init__(new, self)
-        for d in dicts:
-            if d:
-                dict.update(new, d)
-        return new
+
+        """
+        dict.__setitem__(self, key, value)
 
     def __repr__(self):
-        return "immutabledict(%s)" % dict.__repr__(self)
+        return "FacadeDict(%s)" % dict.__repr__(self)
 
 
 class Properties(object):
index 3b175c9353420770e4398dea138724e23d130804..3fa3e864a0c724a6efec2e12289bb9da4273bc30 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -42,6 +42,11 @@ ext_modules = [
         sources=["lib/sqlalchemy/cextension/resultproxy.c"],
         extra_compile_args=extra_compile_args,
     ),
+    Extension(
+        "sqlalchemy.cimmutabledict",
+        sources=["lib/sqlalchemy/cextension/immutabledict.c"],
+        extra_compile_args=extra_compile_args,
+    ),
     Extension(
         "sqlalchemy.cutils",
         sources=["lib/sqlalchemy/cextension/utils.c"],
index a6d777c61df44b1417c204da017e7fcb517a0b2a..bba0cc16c77ac04699072ec3c0bce2d83d2796a3 100644 (file)
@@ -141,11 +141,100 @@ class OrderedSetTest(fixtures.TestBase):
         eq_(o.union(iter([3, 4, 6])), util.OrderedSet([2, 3, 4, 5, 6]))
 
 
-class FrozenDictTest(fixtures.TestBase):
+class ImmutableDictTest(fixtures.TestBase):
+    def test_union_no_change(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = d.union({})
+
+        is_(d2, d)
+
+    def test_merge_with_no_change(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = d.merge_with({}, None)
+
+        eq_(d2, {1: 2, 3: 4})
+        is_(d2, d)
+
+    def test_merge_with_dicts(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = d.merge_with({3: 5, 7: 12}, {9: 18, 15: 25})
+
+        eq_(d, {1: 2, 3: 4})
+        eq_(d2, {1: 2, 3: 5, 7: 12, 9: 18, 15: 25})
+        assert isinstance(d2, util.immutabledict)
+
+        d3 = d.merge_with({17: 42})
+
+        eq_(d3, {1: 2, 3: 4, 17: 42})
+
+    def test_merge_with_tuples(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = d.merge_with([(3, 5), (7, 12)], [(9, 18), (15, 25)])
+
+        eq_(d, {1: 2, 3: 4})
+        eq_(d2, {1: 2, 3: 5, 7: 12, 9: 18, 15: 25})
+
+    def test_union_dictionary(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = d.union({3: 5, 7: 12})
+        assert isinstance(d2, util.immutabledict)
+
+        eq_(d, {1: 2, 3: 4})
+        eq_(d2, {1: 2, 3: 5, 7: 12})
+
+    def test_union_tuples(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = d.union([(3, 5), (7, 12)])
+
+        eq_(d, {1: 2, 3: 4})
+        eq_(d2, {1: 2, 3: 5, 7: 12})
+
+    def test_keys(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        eq_(set(d.keys()), {1, 3})
+
+    def test_values(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        eq_(set(d.values()), {2, 4})
+
+    def test_items(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        eq_(set(d.items()), {(1, 2), (3, 4)})
+
+    def test_contains(self):
+        d = util.immutabledict({1: 2, 3: 4})
+
+        assert 1 in d
+        assert "foo" not in d
+
+    def test_rich_compare(self):
+        d = util.immutabledict({1: 2, 3: 4})
+        d2 = util.immutabledict({1: 2, 3: 4})
+        d3 = util.immutabledict({5: 12})
+        d4 = {5: 12}
+
+        eq_(d, d2)
+        ne_(d, d3)
+        ne_(d, d4)
+        eq_(d3, d4)
+
     def test_serialize(self):
         d = util.immutabledict({1: 2, 3: 4})
         for loads, dumps in picklers():
-            print(loads(dumps(d)))
+            d2 = loads(dumps(d))
+
+            eq_(d2, {1: 2, 3: 4})
+
+            assert isinstance(d2, util.immutabledict)
 
 
 class MemoizedAttrTest(fixtures.TestBase):
index 285dd2acf700a35ca73c2991a6db874844c81edf..fcf31b8261d117efc4da60feadf0ba73c2cbde80 100644 (file)
@@ -49,6 +49,7 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.testing.util import picklers
+from sqlalchemy.util import collections_abc
 
 
 class SomeException(Exception):
@@ -1414,13 +1415,13 @@ class EngineEventsTest(fixtures.TestBase):
             conn, clauseelement, multiparams, params, execution_options
         ):
             assert isinstance(multiparams, (list, tuple))
-            assert isinstance(params, dict)
+            assert isinstance(params, collections_abc.Mapping)
 
         def after_execute(
             conn, clauseelement, multiparams, params, result, execution_options
         ):
             assert isinstance(multiparams, (list, tuple))
-            assert isinstance(params, dict)
+            assert isinstance(params, collections_abc.Mapping)
 
         e1 = testing_engine(config.db_url)
         event.listen(e1, "before_execute", before_execute)
index 00e1d232b7f87a1274c1a4651c8002d017abcdbb..34fee847042e0812afc596140851ca07a848bf4f 100644 (file)
@@ -77,7 +77,7 @@ class PathTest(object):
         return orm_util.PathRegistry.coerce(self._make_path(path))
 
     def _assert_path_result(self, opt, q, paths):
-        q._attributes = q._attributes.copy()
+        q._attributes = dict(q._attributes)
         attr = {}
 
         if isinstance(opt, strategy_options._UnboundLoad):
@@ -1569,7 +1569,7 @@ class SubOptionsTest(PathTest, QueryTest):
 
     def _assert_opts(self, q, sub_opt, non_sub_opts):
         existing_attributes = q._attributes
-        q._attributes = q._attributes.copy()
+        q._attributes = dict(q._attributes)
         attr_a = {}
 
         for val in sub_opt._to_bind:
@@ -1580,7 +1580,7 @@ class SubOptionsTest(PathTest, QueryTest):
                 False,
             )
 
-        q._attributes = existing_attributes.copy()
+        q._attributes = dict(existing_attributes)
 
         attr_b = {}
 
index bec4ecd92928fc6ded183eb1c747d179380e9c55..9017ca84eef59f685d6a4028aa89fb54d70d29d3 100644 (file)
@@ -23,6 +23,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.util import collections_abc
 
 
 class UpdateDeleteTest(fixtures.MappedTest):
@@ -681,7 +682,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
             q.filter(User.id == 15).update({"name": "foob", "id": 123})
             # Confirm that parameters are a dict instead of tuple or list
             params = exec_.mock_calls[0][1][0]._values
-            assert isinstance(params, dict)
+            assert isinstance(params, collections_abc.Mapping)
 
     def test_update_preserve_parameter_order(self):
         User = self.classes.User
index 07f5b80db711c0ecdba4052778fcbbc67fff4150..a75aee6e99a39cd8ef52af417b94ec090fa5e554 100644 (file)
@@ -178,6 +178,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
         eq_(len(metadata.tables), 0)
 
     def test_metadata_tables_immutable(self):
+        # this use case was added due to #1917.
         metadata = MetaData()
 
         Table("t1", metadata, Column("x", Integer))