git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
author     Gaëtan de Menten <gdementen@gmail.com>
           Sat, 13 Feb 2010 22:53:39 +0000 (22:53 +0000)
committer  Gaëtan de Menten <gdementen@gmail.com>
           Sat, 13 Feb 2010 22:53:39 +0000 (22:53 +0000)

- Added an optional C extension to speed up the sql layer by
  reimplementing the highest impact functions.
  The actual speedups will depend heavily on your DBAPI and
  the mix of datatypes used in your tables, and can vary from
  a 50% improvement to more than 200%. It also provides a modest
  (~20%) indirect improvement to ORM speed for large queries.
  Note that it is *not* built/installed by default.
  See README for installation instructions.

- The most common result processor conversion functions were
  moved to the new "processors" module.  Dialect authors are
  encouraged to use those functions whenever they suit their
  needs instead of implementing custom ones.
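
For illustration, the pure-Python fallbacks in the new "processors" module
presumably look something like the sketch below (the function names are taken
from the dialect changes in this commit; the actual bodies in
lib/sqlalchemy/processors.py may differ)::

  # Hypothetical sketch of the pure-Python processors; when the C
  # extension is built, the "cprocessors" versions replace these.
  def to_str(value):
      # None passes through untouched, as in the dialect code it replaces
      return str(value) if value is not None else None

  def to_float(value):
      return float(value) if value is not None else None

  def int_to_boolean(value):
      if value is None:
          return None
      return value and True or False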

27 files changed:
CHANGES
README
lib/sqlalchemy/cextension/processors.c [new file with mode: 0644]
lib/sqlalchemy/cextension/resultproxy.c [new file with mode: 0644]
lib/sqlalchemy/dialects/access/base.py
lib/sqlalchemy/dialects/informix/base.py
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/oursql.py
lib/sqlalchemy/dialects/oracle/zxjdbc.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/pypostgresql.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/processors.py [new file with mode: 0644]
lib/sqlalchemy/test/profiling.py
lib/sqlalchemy/types.py
setup.py
test/aaa_profiling/test_resultset.py
test/aaa_profiling/test_zoomark.py
test/perf/stress_all.py [new file with mode: 0644]
test/perf/stresstest.py [new file with mode: 0644]
test/sql/test_query.py

diff --git a/CHANGES b/CHANGES
index cdf7e71f560edcc1c35fe141cd8ff2dceb10b21e..7b67dc705037e5a7bcb12e48c5882f0f8a878af8 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -59,6 +59,20 @@ CHANGES
       [ticket:1689]
       
 - sql
+  - Added an optional C extension to speed up the sql layer by
+    reimplementing RowProxy and the most common result processors.
+    The actual speedups will depend heavily on your DBAPI and
+    the mix of datatypes used in your tables, and can vary from
+    a 50% improvement to more than 200%.  It also provides a modest
+    (~20%) indirect improvement to ORM speed for large queries.
+    Note that it is *not* built/installed by default.
+    See README for installation instructions.
+
+  - The most common result processor conversion functions were
+    moved to the new "processors" module.  Dialect authors are
+    encouraged to use those functions whenever they suit their
+    needs instead of implementing custom ones.
+
   - Added math negation operator support, -x.
   
   - FunctionElement subclasses are now directly executable the
diff --git a/README b/README
index 7caaf2723c6b2ff2373e9619272765e12a59f941..4bbdd20a116d3eb1aac6ec77f1758b6a8a3d5096 100644 (file)
--- a/README
+++ b/README
@@ -35,6 +35,15 @@ To install::
 To use without installation, include the ``lib`` directory in your Python
 path.
 
+Installing the C extension
+--------------------------
+
+Edit "setup.py" and set ``BUILD_CEXTENSIONS`` to ``True``, then install it as
+above. If you want only to build the extension and not install it, you can do
+so with::
+
+  python setup.py build
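
As a sketch of what this flag presumably controls (the actual setup.py in
this commit may differ in detail), the switch could be wired up like::

  # Hypothetical sketch of the BUILD_CEXTENSIONS switch in setup.py.
  from distutils.core import setup, Extension

  BUILD_CEXTENSIONS = False   # edit to True to build the C extensions

  ext_modules = []
  if BUILD_CEXTENSIONS:
      ext_modules = [
          Extension('sqlalchemy.cprocessors',
                    sources=['lib/sqlalchemy/cextension/processors.c']),
          Extension('sqlalchemy.cresultproxy',
                    sources=['lib/sqlalchemy/cextension/resultproxy.c']),
      ]

  setup(name='SQLAlchemy', ext_modules=ext_modules)  # other args elided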
+
 Running Tests
 -------------
 
diff --git a/lib/sqlalchemy/cextension/processors.c b/lib/sqlalchemy/cextension/processors.c
new file mode 100644 (file)
index 0000000..23b7be4
--- /dev/null
+++ b/lib/sqlalchemy/cextension/processors.c
@@ -0,0 +1,384 @@
+/*
+processors.c
+Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+
+This module is part of SQLAlchemy and is released under
+the MIT License: http://www.opensource.org/licenses/mit-license.php
+*/
+
+#include <Python.h>
+#include <datetime.h>
+
+static PyObject *
+int_to_boolean(PyObject *self, PyObject *arg)
+{
+    long l = 0;
+    PyObject *res;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    l = PyInt_AsLong(arg);
+    if (l == 0) {
+        res = Py_False;
+    } else if (l == 1) {
+        res = Py_True;
+    } else if ((l == -1) && PyErr_Occurred()) {
+        /* -1 can be either the actual value, or an error flag. */
+        return NULL;
+    } else {
+        PyErr_SetString(PyExc_ValueError,
+                        "int_to_boolean only accepts None, 0 or 1");
+        return NULL;
+    }
+
+    Py_INCREF(res);
+    return res;
+}
+
+static PyObject *
+to_str(PyObject *self, PyObject *arg)
+{
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    return PyObject_Str(arg);
+}
+
+static PyObject *
+to_float(PyObject *self, PyObject *arg)
+{
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    return PyNumber_Float(arg);
+}
+
+static PyObject *
+str_to_datetime(PyObject *self, PyObject *arg)
+{
+    const char *str;
+    unsigned int year, month, day, hour, minute, second, microsecond = 0;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    str = PyString_AsString(arg);
+    if (str == NULL)
+        return NULL;
+
+    /* microseconds are optional */
+    /*
+    TODO: this is slightly less picky than the Python version which would
+    not accept "2000-01-01 00:00:00.". I don't know which is better, but they
+    should be coherent.
+    */
+    if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
+               &hour, &minute, &second, &microsecond) < 6) {
+        PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string.");
+        return NULL;
+    }
+    return PyDateTime_FromDateAndTime(year, month, day,
+                                      hour, minute, second, microsecond);
+}
+
+static PyObject *
+str_to_time(PyObject *self, PyObject *arg)
+{
+    const char *str;
+    unsigned int hour, minute, second, microsecond = 0;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    str = PyString_AsString(arg);
+    if (str == NULL)
+        return NULL;
+
+    /* microseconds are optional */
+    /*
+    TODO: this is slightly less picky than the Python version which would
+    not accept "00:00:00.". I don't know which is better, but they should be
+    coherent.
+    */
+    if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
+               &microsecond) < 3) {
+        PyErr_SetString(PyExc_ValueError, "Couldn't parse time string.");
+        return NULL;
+    }
+    return PyTime_FromTime(hour, minute, second, microsecond);
+}
+
+static PyObject *
+str_to_date(PyObject *self, PyObject *arg)
+{
+    const char *str;
+    unsigned int year, month, day;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    str = PyString_AsString(arg);
+    if (str == NULL)
+        return NULL;
+
+    if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) {
+        PyErr_SetString(PyExc_ValueError, "Couldn't parse date string.");
+        return NULL;
+    }
+    return PyDate_FromDate(year, month, day);
+}
+
+
+/***********
+ * Structs *
+ ***********/
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *encoding;
+    PyObject *errors;
+} UnicodeResultProcessor;
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *type;
+} DecimalResultProcessor;
+
+
+
+/**************************
+ * UnicodeResultProcessor *
+ **************************/
+
+static int
+UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
+                            PyObject *kwds)
+{
+    PyObject *encoding, *errors = NULL;
+    static char *kwlist[] = {"encoding", "errors", NULL};
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:init", kwlist,
+                                     &encoding, &errors))
+        return -1;
+
+    Py_INCREF(encoding);
+    self->encoding = encoding;
+
+    if (errors) {
+        Py_INCREF(errors);
+    } else {
+        errors = PyString_FromString("strict");
+        if (errors == NULL)
+            return -1;
+    }
+    self->errors = errors;
+
+    return 0;
+}
+
+static PyObject *
+UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
+{
+    const char *encoding, *errors;
+    char *str;
+    Py_ssize_t len;
+
+    if (value == Py_None)
+        Py_RETURN_NONE;
+
+    if (PyString_AsStringAndSize(value, &str, &len))
+        return NULL;
+
+    encoding = PyString_AS_STRING(self->encoding);
+    errors = PyString_AS_STRING(self->errors);
+
+    return PyUnicode_Decode(str, len, encoding, errors);
+}
+
+static PyMethodDef UnicodeResultProcessor_methods[] = {
+    {"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
+     "The value processor itself."},
+    {NULL}  /* Sentinel */
+};
+
+static PyTypeObject UnicodeResultProcessorType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                          /* ob_size */
+    "sqlalchemy.cprocessors.UnicodeResultProcessor",        /* tp_name */
+    sizeof(UnicodeResultProcessor),             /* tp_basicsize */
+    0,                                          /* tp_itemsize */
+    0,                                          /* tp_dealloc */
+    0,                                          /* tp_print */
+    0,                                          /* tp_getattr */
+    0,                                          /* tp_setattr */
+    0,                                          /* tp_compare */
+    0,                                          /* tp_repr */
+    0,                                          /* tp_as_number */
+    0,                                          /* tp_as_sequence */
+    0,                                          /* tp_as_mapping */
+    0,                                          /* tp_hash  */
+    0,                                          /* tp_call */
+    0,                                          /* tp_str */
+    0,                                          /* tp_getattro */
+    0,                                          /* tp_setattro */
+    0,                                          /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
+    "UnicodeResultProcessor objects",           /* tp_doc */
+    0,                                          /* tp_traverse */
+    0,                                          /* tp_clear */
+    0,                                          /* tp_richcompare */
+    0,                                          /* tp_weaklistoffset */
+    0,                                          /* tp_iter */
+    0,                                          /* tp_iternext */
+    UnicodeResultProcessor_methods,             /* tp_methods */
+    0,                                          /* tp_members */
+    0,                                          /* tp_getset */
+    0,                                          /* tp_base */
+    0,                                          /* tp_dict */
+    0,                                          /* tp_descr_get */
+    0,                                          /* tp_descr_set */
+    0,                                          /* tp_dictoffset */
+    (initproc)UnicodeResultProcessor_init,      /* tp_init */
+    0,                                          /* tp_alloc */
+    0,                                          /* tp_new */
+};
+
+/**************************
+ * DecimalResultProcessor *
+ **************************/
+
+static int
+DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args,
+                            PyObject *kwds)
+{
+    PyObject *type;
+
+    if (!PyArg_ParseTuple(args, "O", &type))
+        return -1;
+
+    Py_INCREF(type);
+    self->type = type;
+
+    return 0;
+}
+
+static PyObject *
+DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value)
+{
+    PyObject *str, *result;
+
+    if (value == Py_None)
+        Py_RETURN_NONE;
+
+    if (PyFloat_CheckExact(value)) {
+        /* Decimal does not accept float values directly */
+        str = PyObject_Str(value);
+        if (str == NULL)
+            return NULL;
+        result = PyObject_CallFunctionObjArgs(self->type, str, NULL);
+        Py_DECREF(str);
+        return result;
+    } else {
+        return PyObject_CallFunctionObjArgs(self->type, value, NULL);
+    }
+}
+
+static PyMethodDef DecimalResultProcessor_methods[] = {
+    {"process", (PyCFunction)DecimalResultProcessor_process, METH_O,
+     "The value processor itself."},
+    {NULL}  /* Sentinel */
+};
+
+static PyTypeObject DecimalResultProcessorType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                          /* ob_size */
+    "sqlalchemy.DecimalResultProcessor",        /* tp_name */
+    sizeof(DecimalResultProcessor),             /* tp_basicsize */
+    0,                                          /* tp_itemsize */
+    0,                                          /* tp_dealloc */
+    0,                                          /* tp_print */
+    0,                                          /* tp_getattr */
+    0,                                          /* tp_setattr */
+    0,                                          /* tp_compare */
+    0,                                          /* tp_repr */
+    0,                                          /* tp_as_number */
+    0,                                          /* tp_as_sequence */
+    0,                                          /* tp_as_mapping */
+    0,                                          /* tp_hash  */
+    0,                                          /* tp_call */
+    0,                                          /* tp_str */
+    0,                                          /* tp_getattro */
+    0,                                          /* tp_setattro */
+    0,                                          /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
+    "DecimalResultProcessor objects",           /* tp_doc */
+    0,                                          /* tp_traverse */
+    0,                                          /* tp_clear */
+    0,                                          /* tp_richcompare */
+    0,                                          /* tp_weaklistoffset */
+    0,                                          /* tp_iter */
+    0,                                          /* tp_iternext */
+    DecimalResultProcessor_methods,             /* tp_methods */
+    0,                                          /* tp_members */
+    0,                                          /* tp_getset */
+    0,                                          /* tp_base */
+    0,                                          /* tp_dict */
+    0,                                          /* tp_descr_get */
+    0,                                          /* tp_descr_set */
+    0,                                          /* tp_dictoffset */
+    (initproc)DecimalResultProcessor_init,      /* tp_init */
+    0,                                          /* tp_alloc */
+    0,                                          /* tp_new */
+};
+
+#ifndef PyMODINIT_FUNC  /* declarations for DLL import/export */
+#define PyMODINIT_FUNC void
+#endif
+
+
+static PyMethodDef module_methods[] = {
+    {"int_to_boolean", int_to_boolean, METH_O,
+     "Convert an integer to a boolean."},
+    {"to_str", to_str, METH_O,
+     "Convert any value to its string representation."},
+    {"to_float", to_float, METH_O,
+     "Convert any value to its floating point representation."},
+    {"str_to_datetime", str_to_datetime, METH_O,
+     "Convert an ISO string to a datetime.datetime object."},
+    {"str_to_time", str_to_time, METH_O,
+     "Convert an ISO string to a datetime.time object."},
+    {"str_to_date", str_to_date, METH_O,
+     "Convert an ISO string to a datetime.date object."},
+    {NULL, NULL, 0, NULL}        /* Sentinel */
+};
+
+PyMODINIT_FUNC
+initcprocessors(void)
+{
+    PyObject *m;
+
+    UnicodeResultProcessorType.tp_new = PyType_GenericNew;
+    if (PyType_Ready(&UnicodeResultProcessorType) < 0)
+        return;
+
+    DecimalResultProcessorType.tp_new = PyType_GenericNew;
+    if (PyType_Ready(&DecimalResultProcessorType) < 0)
+        return;
+
+    m = Py_InitModule3("cprocessors", module_methods,
+                       "Module containing C versions of data processing functions.");
+    if (m == NULL)
+        return;
+
+    PyDateTime_IMPORT;
+
+    Py_INCREF(&UnicodeResultProcessorType);
+    PyModule_AddObject(m, "UnicodeResultProcessor",
+                       (PyObject *)&UnicodeResultProcessorType);
+
+    Py_INCREF(&DecimalResultProcessorType);
+    PyModule_AddObject(m, "DecimalResultProcessor",
+                       (PyObject *)&DecimalResultProcessorType);
+}
+
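
Once built, the cprocessors functions behave like their pure-Python
counterparts; an illustrative (hypothetical) check, assuming the extension
installs as sqlalchemy.cprocessors per the module name above::

  from sqlalchemy import cprocessors
  import datetime

  assert cprocessors.int_to_boolean(1) is True
  assert cprocessors.to_float("3.5") == 3.5
  assert cprocessors.str_to_date("2010-02-13") == datetime.date(2010, 2, 13)

  # stateful processors are constructed once per result column,
  # then called once per value:
  proc = cprocessors.UnicodeResultProcessor("utf-8")
  assert proc.process("caf\xc3\xa9") == u"caf\xe9"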
diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c
new file mode 100644 (file)
index 0000000..14ea182
--- /dev/null
+++ b/lib/sqlalchemy/cextension/resultproxy.c
@@ -0,0 +1,586 @@
+/*
+resultproxy.c
+Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+
+This module is part of SQLAlchemy and is released under
+the MIT License: http://www.opensource.org/licenses/mit-license.php
+*/
+
+#include <Python.h>
+
+
+/***********
+ * Structs *
+ ***********/
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *parent;
+    PyObject *row;
+    PyObject *processors;
+    PyObject *keymap;
+} BaseRowProxy;
+
+/****************
+ * BaseRowProxy *
+ ****************/
+
+static PyObject *
+rowproxy_reconstructor(PyObject *self, PyObject *args)
+{
+    PyObject *cls, *state, *tmp;
+    BaseRowProxy *obj;
+
+    if (!PyArg_ParseTuple(args, "OO", &cls, &state))
+        return NULL;
+
+    obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls);
+    if (obj == NULL)
+        return NULL;
+
+    tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state);
+    if (tmp == NULL) {
+        Py_DECREF(obj);
+        return NULL;
+    }
+    Py_DECREF(tmp);
+
+    if (obj->parent == NULL || obj->row == NULL ||
+        obj->processors == NULL || obj->keymap == NULL) {
+        PyErr_SetString(PyExc_RuntimeError,
+            "__setstate__ for BaseRowProxy subclasses must set values "
+            "for parent, row, processors and keymap");
+        Py_DECREF(obj);
+        return NULL;
+    }
+
+    return (PyObject *)obj;
+}
+
+static int
+BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds)
+{
+    PyObject *parent, *row, *processors, *keymap;
+
+    if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4,
+                           &parent, &row, &processors, &keymap))
+        return -1;
+
+    Py_INCREF(parent);
+    self->parent = parent;
+
+    if (!PyTuple_CheckExact(row)) {
+        PyErr_SetString(PyExc_TypeError, "row must be a tuple");
+        return -1;
+    }
+    Py_INCREF(row);
+    self->row = row;
+
+    if (!PyList_CheckExact(processors)) {
+        PyErr_SetString(PyExc_TypeError, "processors must be a list");
+        return -1;
+    }
+    Py_INCREF(processors);
+    self->processors = processors;
+
+    if (!PyDict_CheckExact(keymap)) {
+        PyErr_SetString(PyExc_TypeError, "keymap must be a dict");
+        return -1;
+    }
+    Py_INCREF(keymap);
+    self->keymap = keymap;
+
+    return 0;
+}
+
+/* We need the reduce method because otherwise the default implementation
+ * does very weird stuff for pickle protocol 0 and 1. It calls
+ * BaseRowProxy.__new__(RowProxy_instance) upon *pickling*.
+ */
+static PyObject *
+BaseRowProxy_reduce(PyObject *self)
+{
+    PyObject *method, *state;
+    PyObject *module, *reconstructor, *cls;
+
+    method = PyObject_GetAttrString(self, "__getstate__");
+    if (method == NULL)
+        return NULL;
+
+    state = PyObject_CallObject(method, NULL);
+    Py_DECREF(method);
+    if (state == NULL)
+        return NULL;
+
+    module = PyImport_ImportModule("sqlalchemy.engine.base");
+    if (module == NULL)
+        return NULL;
+
+    reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
+    Py_DECREF(module);
+    if (reconstructor == NULL) {
+        Py_DECREF(state);
+        return NULL;
+    }
+
+    cls = PyObject_GetAttrString(self, "__class__");
+    if (cls == NULL) {
+        Py_DECREF(reconstructor);
+        Py_DECREF(state);
+        return NULL;
+    }
+
+    return Py_BuildValue("(N(NN))", reconstructor, cls, state);
+}
+
+static void
+BaseRowProxy_dealloc(BaseRowProxy *self)
+{
+    Py_XDECREF(self->parent);
+    Py_XDECREF(self->row);
+    Py_XDECREF(self->processors);
+    Py_XDECREF(self->keymap);
+    self->ob_type->tp_free((PyObject *)self);
+}
+
+static PyObject *
+BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
+{
+    Py_ssize_t num_values, num_processors;
+    PyObject **valueptr, **funcptr, **resultptr;
+    PyObject *func, *result, *processed_value;
+
+    num_values = Py_SIZE(values);
+    num_processors = Py_SIZE(processors);
+    if (num_values != num_processors) {
+        PyErr_SetString(PyExc_RuntimeError,
+            "number of values in row difer from number of column processors");
+        return NULL;
+    }
+
+    if (astuple) {
+        result = PyTuple_New(num_values);
+    } else {
+        result = PyList_New(num_values);
+    }
+    if (result == NULL)
+        return NULL;
+
+    /* we don't need to use PySequence_Fast as long as values, processors and
+     * result are simple tuple or lists. */
+    valueptr = PySequence_Fast_ITEMS(values);
+    funcptr = PySequence_Fast_ITEMS(processors);
+    resultptr = PySequence_Fast_ITEMS(result);
+    while (--num_values >= 0) {
+        func = *funcptr;
+        if (func != Py_None) {
+            processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
+                                                           NULL);
+            if (processed_value == NULL) {
+                Py_DECREF(result);
+                return NULL;
+            }
+            *resultptr = processed_value;
+        } else {
+            Py_INCREF(*valueptr);
+            *resultptr = *valueptr;
+        }
+        valueptr++;
+        funcptr++;
+        resultptr++;
+    }
+    return result;
+}
+
+static PyListObject *
+BaseRowProxy_values(BaseRowProxy *self)
+{
+    return (PyListObject *)BaseRowProxy_processvalues(self->row,
+                                                      self->processors, 0);
+}
+
+static PyTupleObject *
+BaseRowProxy_tuplevalues(BaseRowProxy *self)
+{
+    return (PyTupleObject *)BaseRowProxy_processvalues(self->row,
+                                                       self->processors, 1);
+}
+
+static PyObject *
+BaseRowProxy_iter(BaseRowProxy *self)
+{
+    PyObject *values, *result;
+
+    values = (PyObject *)BaseRowProxy_tuplevalues(self);
+    if (values == NULL)
+        return NULL;
+
+    result = PyObject_GetIter(values);
+    Py_DECREF(values);
+    if (result == NULL)
+        return NULL;
+
+    return result;
+}
+
+static Py_ssize_t
+BaseRowProxy_length(BaseRowProxy *self)
+{
+    return Py_SIZE(self->row);
+}
+
+static PyObject *
+BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
+{
+    PyObject *processors, *values;
+    PyObject *processor, *value;
+    PyObject *record, *result, *indexobject;
+    PyObject *exc_module, *exception;
+    char *cstr_key;
+    long index;
+
+    if (PyInt_CheckExact(key)) {
+        index = PyInt_AS_LONG(key);
+    } else if (PyLong_CheckExact(key)) {
+        index = PyLong_AsLong(key);
+        if ((index == -1) && PyErr_Occurred())
+            /* -1 can be either the actual value, or an error flag. */
+            return NULL;
+    } else if (PySlice_Check(key)) {
+        values = PyObject_GetItem(self->row, key);
+        if (values == NULL)
+            return NULL;
+
+        processors = PyObject_GetItem(self->processors, key);
+        if (processors == NULL) {
+            Py_DECREF(values);
+            return NULL;
+        }
+
+        result = BaseRowProxy_processvalues(values, processors, 1);
+        Py_DECREF(values);
+        Py_DECREF(processors);
+        return result;
+    } else {
+        record = PyDict_GetItem((PyObject *)self->keymap, key);
+        if (record == NULL) {
+            record = PyObject_CallMethod(self->parent, "_key_fallback",
+                                         "O", key);
+            if (record == NULL)
+                return NULL;
+        }
+
+        indexobject = PyTuple_GetItem(record, 1);
+        if (indexobject == NULL)
+            return NULL;
+
+        if (indexobject == Py_None) {
+            exc_module = PyImport_ImportModule("sqlalchemy.exc");
+            if (exc_module == NULL)
+                return NULL;
+
+            exception = PyObject_GetAttrString(exc_module,
+                                               "InvalidRequestError");
+            Py_DECREF(exc_module);
+            if (exception == NULL)
+                return NULL;
+
+            cstr_key = PyString_AsString(key);
+            if (cstr_key == NULL) {
+                Py_DECREF(exception);
+                return NULL;
+            }
+
+            PyErr_Format(exception,
+                    "Ambiguous column name '%s' in result set! "
+                    "try 'use_labels' option on select statement.", cstr_key);
+            /* drop our reference to the exception class before returning */
+            Py_DECREF(exception);
+            return NULL;
+        }
+
+        index = PyInt_AsLong(indexobject);
+        if ((index == -1) && PyErr_Occurred())
+            /* -1 can be either the actual value, or an error flag. */
+            return NULL;
+    }
+    processor = PyList_GetItem(self->processors, index);
+    if (processor == NULL)
+        return NULL;
+
+    value = PyTuple_GetItem(self->row, index);
+    if (value == NULL)
+        return NULL;
+
+    if (processor != Py_None) {
+        return PyObject_CallFunctionObjArgs(processor, value, NULL);
+    } else {
+        Py_INCREF(value);
+        return value;
+    }
+}
+
+static PyObject *
+BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name)
+{
+    PyObject *tmp;
+
+    if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) {
+        if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+            return NULL;
+        PyErr_Clear();
+    }
+    else
+        return tmp;
+
+    return BaseRowProxy_subscript(self, name);
+}
+
+/***********************
+ * getters and setters *
+ ***********************/
+
+static PyObject *
+BaseRowProxy_getparent(BaseRowProxy *self, void *closure)
+{
+    Py_INCREF(self->parent);
+    return self->parent;
+}
+
+static int
+BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure)
+{
+    PyObject *module, *cls;
+
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "Cannot delete the 'parent' attribute");
+        return -1;
+    }
+
+    module = PyImport_ImportModule("sqlalchemy.engine.base");
+    if (module == NULL)
+        return -1;
+
+    cls = PyObject_GetAttrString(module, "ResultMetaData");
+    Py_DECREF(module);
+    if (cls == NULL)
+        return -1;
+
+    if (PyObject_IsInstance(value, cls) != 1) {
+        PyErr_SetString(PyExc_TypeError,
+                        "The 'parent' attribute value must be an instance of "
+                        "ResultMetaData");
+        return -1;
+    }
+    Py_DECREF(cls);
+    Py_XDECREF(self->parent);
+    Py_INCREF(value);
+    self->parent = value;
+
+    return 0;
+}
+
+static PyObject *
+BaseRowProxy_getrow(BaseRowProxy *self, void *closure)
+{
+    Py_INCREF(self->row);
+    return self->row;
+}
+
+static int
+BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure)
+{
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "Cannot delete the 'row' attribute");
+        return -1;
+    }
+
+    if (!PyTuple_CheckExact(value)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "The 'row' attribute value must be a tuple");
+        return -1;
+    }
+
+    Py_XDECREF(self->row);
+    Py_INCREF(value);
+    self->row = value;
+
+    return 0;
+}
+
+static PyObject *
+BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
+{
+    Py_INCREF(self->processors);
+    return self->processors;
+}
+
+static int
+BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
+{
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "Cannot delete the 'processors' attribute");
+        return -1;
+    }
+
+    if (!PyList_CheckExact(value)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "The 'processors' attribute value must be a list");
+        return -1;
+    }
+
+    Py_XDECREF(self->processors);
+    Py_INCREF(value);
+    self->processors = value;
+
+    return 0;
+}
+
+static PyObject *
+BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
+{
+    Py_INCREF(self->keymap);
+    return self->keymap;
+}
+
+static int
+BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
+{
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "Cannot delete the 'keymap' attribute");
+        return -1;
+    }
+
+    if (!PyDict_CheckExact(value)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "The 'keymap' attribute value must be a dict");
+        return -1;
+    }
+
+    Py_XDECREF(self->keymap);
+    Py_INCREF(value);
+    self->keymap = value;
+
+    return 0;
+}
+
+static PyGetSetDef BaseRowProxy_getseters[] = {
+    {"_parent",
+     (getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
+     "ResultMetaData",
+     NULL},
+    {"_row",
+     (getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
+     "Original row tuple",
+     NULL},
+    {"_processors",
+     (getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
+     "list of type processors",
+     NULL},
+    {"_keymap",
+     (getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
+     "Key to (processor, index) dict",
+     NULL},
+    {NULL}
+};
+
+static PyMethodDef BaseRowProxy_methods[] = {
+    {"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
+     "Return the values represented by this BaseRowProxy as a list."},
+       {"__reduce__",  (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
+        "Pickle support method."},
+    {NULL}  /* Sentinel */
+};
+
+static PySequenceMethods BaseRowProxy_as_sequence = {
+    (lenfunc)BaseRowProxy_length,   /* sq_length */
+    0,                              /* sq_concat */
+    0,                              /* sq_repeat */
+    0,                              /* sq_item */
+    0,                              /* sq_slice */
+    0,                              /* sq_ass_item */
+    0,                              /* sq_ass_slice */
+    0,                              /* sq_contains */
+    0,                              /* sq_inplace_concat */
+    0,                              /* sq_inplace_repeat */
+};
+
+static PyMappingMethods BaseRowProxy_as_mapping = {
+    (lenfunc)BaseRowProxy_length,       /* mp_length */
+    (binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
+    0                                   /* mp_ass_subscript */
+};
+
+static PyTypeObject BaseRowProxyType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                  /* ob_size */
+    "sqlalchemy.cresultproxy.BaseRowProxy",          /* tp_name */
+    sizeof(BaseRowProxy),               /* tp_basicsize */
+    0,                                  /* tp_itemsize */
+    (destructor)BaseRowProxy_dealloc,   /* tp_dealloc */
+    0,                                  /* tp_print */
+    0,                                  /* tp_getattr */
+    0,                                  /* tp_setattr */
+    0,                                  /* tp_compare */
+    0,                                  /* tp_repr */
+    0,                                  /* tp_as_number */
+    &BaseRowProxy_as_sequence,          /* tp_as_sequence */
+    &BaseRowProxy_as_mapping,           /* tp_as_mapping */
+    0,                                  /* tp_hash */
+    0,                                  /* tp_call */
+    0,                                  /* tp_str */
+    (getattrofunc)BaseRowProxy_getattro,/* tp_getattro */
+    0,                                  /* tp_setattro */
+    0,                                  /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,               /* tp_flags */
+    "BaseRowProxy is a abstract base class for RowProxy",   /* tp_doc */
+    0,                                  /* tp_traverse */
+    0,                                  /* tp_clear */
+    0,                                  /* tp_richcompare */
+    0,                                  /* tp_weaklistoffset */
+    (getiterfunc)BaseRowProxy_iter,     /* tp_iter */
+    0,                                  /* tp_iternext */
+    BaseRowProxy_methods,               /* tp_methods */
+    0,                                  /* tp_members */
+    BaseRowProxy_getseters,             /* tp_getset */
+    0,                                  /* tp_base */
+    0,                                  /* tp_dict */
+    0,                                  /* tp_descr_get */
+    0,                                  /* tp_descr_set */
+    0,                                  /* tp_dictoffset */
+    (initproc)BaseRowProxy_init,        /* tp_init */
+    0,                                  /* tp_alloc */
+    0                                   /* tp_new */
+};
+
+
+#ifndef PyMODINIT_FUNC  /* declarations for DLL import/export */
+#define PyMODINIT_FUNC void
+#endif
+
+
+static PyMethodDef module_methods[] = {
+    {"rowproxy_reconstructor", rowproxy_reconstructor, METH_VARARGS,
+     "reconstruct a RowProxy instance from its pickled form."},
+    {NULL, NULL, 0, NULL}        /* Sentinel */
+};
+
+PyMODINIT_FUNC
+initcresultproxy(void)
+{
+    PyObject *m;
+
+    BaseRowProxyType.tp_new = PyType_GenericNew;
+    if (PyType_Ready(&BaseRowProxyType) < 0)
+        return;
+
+    m = Py_InitModule3("cresultproxy", module_methods,
+                       "Module containing C versions of core ResultProxy classes.");
+    if (m == NULL)
+        return;
+
+    Py_INCREF(&BaseRowProxyType);
+    PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType);
+
+}
+
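
Behaviorally, BaseRowProxy gives rows tuple-style and mapping-style access
with the column processors applied lazily; a simplified Python equivalent
of the __getitem__ logic above (error handling and the _key_fallback path
omitted)::

  def getitem(row, key):
      if isinstance(key, (int, long)):
          index = key
      elif isinstance(key, slice):
          # slices process and return a tuple of values
          return tuple(p(v) if p is not None else v
                       for v, p in zip(row._row[key], row._processors[key]))
      else:
          # other keys go through the keymap: key -> (processor, index)
          index = row._keymap[key][1]
      p = row._processors[index]
      v = row._row[index]
      return p(v) if p is not None else v

Attribute access (row.somecolumn) falls back to the same lookup when normal
attribute resolution fails, per BaseRowProxy_getattro above.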
diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py
index a46ad247a4e7089d65e994685096021b070d93fd..c10e77011cc2f5807242c2e1e979ad0b2ffcac13 100644 (file)
--- a/lib/sqlalchemy/dialects/access/base.py
+++ b/lib/sqlalchemy/dialects/access/base.py
@@ -17,23 +17,17 @@ This dialect is *not* tested on SQLAlchemy 0.6.
 from sqlalchemy import sql, schema, types, exc, pool
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, base
-
+from sqlalchemy import processors
 
 class AcNumeric(types.Numeric):
-    def result_processor(self, dialect, coltype):
-        return None
+    def get_col_spec(self):
+        return "NUMERIC"
 
     def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                # Not sure that this exception is needed
-                return value
-            else:
-                return str(value)
-        return process
+        return processors.to_str
 
-    def get_col_spec(self):
-        return "NUMERIC"
+    def result_processor(self, dialect, coltype):
+        return None
 
 class AcFloat(types.Float):
     def get_col_spec(self):
@@ -41,11 +35,7 @@ class AcFloat(types.Float):
 
     def bind_processor(self, dialect):
         """By converting to string, we can use Decimal types round-trip."""
-        def process(value):
-            if not value is None:
-                return str(value)
-            return None
-        return process
+        return processors.to_str
 
 class AcInteger(types.Integer):
     def get_col_spec(self):
@@ -103,25 +93,6 @@ class AcBoolean(types.Boolean):
     def get_col_spec(self):
         return "YESNO"
 
-    def result_processor(self, dialect, coltype):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is True:
-                return 1
-            elif value is False:
-                return 0
-            elif value is None:
-                return None
-            else:
-                return value and True or False
-        return process
-
 class AcTimeStamp(types.TIMESTAMP):
     def get_col_spec(self):
         return "TIMESTAMP"
@@ -443,4 +414,4 @@ dialect.poolclass = pool.SingletonThreadPool
 dialect.statement_compiler = AccessCompiler
 dialect.ddlcompiler = AccessDDLCompiler
 dialect.preparer = AccessIdentifierPreparer
-dialect.execution_ctx_cls = AccessExecutionContext
\ No newline at end of file
+dialect.execution_ctx_cls = AccessExecutionContext
diff --git a/lib/sqlalchemy/dialects/informix/base.py b/lib/sqlalchemy/dialects/informix/base.py
index 2802d493a6d80a0404e5d5e9d872a6d027239225..54aae6eb34275882a97a90bf36eebcd630ddc500 100644 (file)
--- a/lib/sqlalchemy/dialects/informix/base.py
+++ b/lib/sqlalchemy/dialects/informix/base.py
@@ -302,4 +302,4 @@ class InformixDialect(default.DefaultDialect):
     @reflection.cache
     def get_indexes(self, connection, table_name, schema, **kw):
         # TODO
-        return []
\ No newline at end of file
+        return []
diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py
index 2e0b9518b0035bb49b1469e328a022a22e90a11d..f409f3213b87bdc4f4bc0191acd3feb19d2f5378 100644 (file)
--- a/lib/sqlalchemy/dialects/maxdb/base.py
+++ b/lib/sqlalchemy/dialects/maxdb/base.py
@@ -60,7 +60,7 @@ this.
 """
 import datetime, itertools, re
 
-from sqlalchemy import exc, schema, sql, util
+from sqlalchemy import exc, schema, sql, util, processors
 from sqlalchemy.sql import operators as sql_operators, expression as sql_expr
 from sqlalchemy.sql import compiler, visitors
 from sqlalchemy.engine import base as engine_base, default
@@ -86,6 +86,12 @@ class _StringType(sqltypes.String):
             return process
 
     def result_processor(self, dialect, coltype):
+        #XXX: this code is probably very slow and one should try (if at all
+        # possible) to determine the correct code path on a per-connection
+        # basis (ie, here in result_processor, instead of inside the processor
+        # function itself) and probably also use a few generic
+        # processors, or possibly per query (though there is no mechanism
+        # for that yet).
         def process(value):
             while True:
                 if value is None:
@@ -152,6 +158,7 @@ class MaxNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
         return None
 
+
 class MaxTimestamp(sqltypes.DateTime):
     def bind_processor(self, dialect):
         def process(value):
@@ -172,25 +179,30 @@ class MaxTimestamp(sqltypes.DateTime):
         return process
 
     def result_processor(self, dialect, coltype):
-        def process(value):
-            if value is None:
-                return None
-            elif dialect.datetimeformat == 'internal':
-                return datetime.datetime(
-                    *[int(v)
-                      for v in (value[0:4], value[4:6], value[6:8],
-                                value[8:10], value[10:12], value[12:14],
-                                value[14:])])
-            elif dialect.datetimeformat == 'iso':
-                return datetime.datetime(
-                    *[int(v)
-                      for v in (value[0:4], value[5:7], value[8:10],
-                                value[11:13], value[14:16], value[17:19],
-                                value[20:])])
-            else:
-                raise exc.InvalidRequestError(
-                    "datetimeformat '%s' is not supported." % (
-                    dialect.datetimeformat,))
+        if dialect.datetimeformat == 'internal':
+            def process(value):
+                if value is None:
+                    return None
+                else:
+                    return datetime.datetime(
+                        *[int(v)
+                          for v in (value[0:4], value[4:6], value[6:8],
+                                    value[8:10], value[10:12], value[12:14],
+                                    value[14:])])
+        elif dialect.datetimeformat == 'iso':
+            def process(value):
+                if value is None:
+                    return None
+                else:
+                    return datetime.datetime(
+                        *[int(v)
+                          for v in (value[0:4], value[5:7], value[8:10],
+                                    value[11:13], value[14:16], value[17:19],
+                                    value[20:])])
+        else:
+            raise exc.InvalidRequestError(
+                "datetimeformat '%s' is not supported." % 
+                dialect.datetimeformat)
         return process
 
 
@@ -212,19 +224,24 @@ class MaxDate(sqltypes.Date):
         return process
 
     def result_processor(self, dialect, coltype):
-        def process(value):
-            if value is None:
-                return None
-            elif dialect.datetimeformat == 'internal':
-                return datetime.date(
-                    *[int(v) for v in (value[0:4], value[4:6], value[6:8])])
-            elif dialect.datetimeformat == 'iso':
-                return datetime.date(
-                    *[int(v) for v in (value[0:4], value[5:7], value[8:10])])
-            else:
-                raise exc.InvalidRequestError(
-                    "datetimeformat '%s' is not supported." % (
-                    dialect.datetimeformat,))
+        if dialect.datetimeformat == 'internal':
+            def process(value):
+                if value is None:
+                    return None
+                else:
+                    return datetime.date(int(value[0:4]), int(value[4:6]), 
+                                         int(value[6:8]))
+        elif dialect.datetimeformat == 'iso':
+            def process(value):
+                if value is None:
+                    return None
+                else:
+                    return datetime.date(int(value[0:4]), int(value[5:7]), 
+                                         int(value[8:10]))
+        else:
+            raise exc.InvalidRequestError(
+                "datetimeformat '%s' is not supported." % 
+                dialect.datetimeformat)
         return process
 
 
@@ -246,31 +263,30 @@ class MaxTime(sqltypes.Time):
         return process
 
     def result_processor(self, dialect, coltype):
-        def process(value):
-            if value is None:
-                return None
-            elif dialect.datetimeformat == 'internal':
-                t = datetime.time(
-                    *[int(v) for v in (value[0:4], value[4:6], value[6:8])])
-                return t
-            elif dialect.datetimeformat == 'iso':
-                return datetime.time(
-                    *[int(v) for v in (value[0:4], value[5:7], value[8:10])])
-            else:
-                raise exc.InvalidRequestError(
-                    "datetimeformat '%s' is not supported." % (
-                    dialect.datetimeformat,))
+        if dialect.datetimeformat == 'internal':
+            def process(value):
+                if value is None:
+                    return None
+                else:
+                    return datetime.time(int(value[0:4]), int(value[4:6]), 
+                                         int(value[6:8]))
+        elif dialect.datetimeformat == 'iso':
+            def process(value):
+                if value is None:
+                    return None
+                else:
+                    return datetime.time(int(value[0:4]), int(value[5:7]),
+                                         int(value[8:10]))
+        else:
+            raise exc.InvalidRequestError(
+                "datetimeformat '%s' is not supported." % 
+                dialect.datetimeformat)
         return process
 
 
 class MaxBlob(sqltypes.LargeBinary):
     def bind_processor(self, dialect):
-        def process(value):
-            if value is None:
-                return None
-            else:
-                return str(value)
-        return process
+        return processors.to_str
 
     def result_processor(self, dialect, coltype):
         def process(value):
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 4e58d64b362d40d5b192b23cef20f059d257fbc2..3f4e0b9f34b6da81a5515499feb26c2e2b55afec 100644 (file)
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -233,11 +233,10 @@ from sqlalchemy.sql import select, compiler, expression, \
                             functions as sql_functions, util as sql_util
 from sqlalchemy.engine import default, base, reflection
 from sqlalchemy import types as sqltypes
-from decimal import Decimal as _python_Decimal
+from sqlalchemy import processors
 from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
                                 FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\
                                 VARBINARY, BLOB
-            
 
 from sqlalchemy.dialects.mssql import information_schema as ischema
 
@@ -280,22 +279,12 @@ RESERVED_WORDS = set(
 class _MSNumeric(sqltypes.Numeric):
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            def process(value):
-                if value is not None:
-                    return _python_Decimal(str(value))
-                else:
-                    return value
-            return process
+            return processors.to_decimal_processor_factory(decimal.Decimal)
         else:
             #XXX: if the DBAPI returns a float (this is likely, given the
             # processor when asdecimal is True), this should be a None
             # processor instead.
-            def process(value):
-                if value is not None:
-                    return float(value)
-                else:
-                    return value
-            return process
+            return processors.to_float
 
     def bind_processor(self, dialect):
         def process(value):
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index eb348f1a16378bb64745874b4744fbf5f319fa73..82a4af941fdb021530125dff96893bda4803562c 100644 (file)
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -351,7 +351,8 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL):
           numeric.
 
         """
-        super(DECIMAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
+        super(DECIMAL, self).__init__(precision=precision, scale=scale,
+                                      asdecimal=asdecimal, **kw)
 
     
 class DOUBLE(_FloatType):
@@ -375,7 +376,8 @@ class DOUBLE(_FloatType):
           numeric.
 
         """
-        super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
+        super(DOUBLE, self).__init__(precision=precision, scale=scale,
+                                     asdecimal=asdecimal, **kw)
 
 class REAL(_FloatType):
     """MySQL REAL type."""
@@ -398,7 +400,8 @@ class REAL(_FloatType):
           numeric.
 
         """
-        super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
+        super(REAL, self).__init__(precision=precision, scale=scale,
+                                   asdecimal=asdecimal, **kw)
 
 class FLOAT(_FloatType, sqltypes.FLOAT):
     """MySQL FLOAT type."""
@@ -421,7 +424,8 @@ class FLOAT(_FloatType, sqltypes.FLOAT):
           numeric.
 
         """
-        super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
+        super(FLOAT, self).__init__(precision=precision, scale=scale,
+                                    asdecimal=asdecimal, **kw)
 
     def bind_processor(self, dialect):
         return None
@@ -2459,6 +2463,7 @@ class _DecodingRowProxy(object):
     def __init__(self, rowproxy, charset):
         self.rowproxy = rowproxy
         self.charset = charset
+
     def __getitem__(self, index):
         item = self.rowproxy[index]
         if isinstance(item, _array):
@@ -2467,6 +2472,7 @@ class _DecodingRowProxy(object):
             return item.decode(self.charset)
         else:
             return item
+
     def __getattr__(self, attr):
         item = getattr(self.rowproxy, attr)
         if isinstance(item, _array):
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index c07ed87135d8c5458f6436821d4f4eca0194e47c..8cfd5930f69e6c3d66d8340e8bd830eddbab8fab 100644 (file)
--- a/lib/sqlalchemy/dialects/mysql/mysqldb.py
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -28,6 +28,7 @@ from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutio
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
+from sqlalchemy import processors
 
 class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
     
@@ -51,12 +52,7 @@ class _DecimalType(_NumericType):
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             return None
-        def process(value):
-            if value is not None:
-                return float(value)
-            else:
-                return value
-        return process
+        return processors.to_float
 
 
 class _MySQLdbNumeric(_DecimalType, NUMERIC):
diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py
index a03aa988eb9392ca4d063ba471f35b300561a9e3..1fca6850a14157140c524094ead6d2b960793a8e 100644 (file)
--- a/lib/sqlalchemy/dialects/mysql/oursql.py
+++ b/lib/sqlalchemy/dialects/mysql/oursql.py
@@ -29,18 +29,14 @@ from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionCon
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
+from sqlalchemy import processors
 
 
 class _oursqlNumeric(NUMERIC):
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             return None
-        def process(value):
-            if value is not None:
-                return float(value)
-            else:
-                return value
-        return process
+        return processors.to_float
 
 
 class _oursqlBIT(BIT):
diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
index 22a1f443cfbb08ae5d2e97ba7b730e9a0ffc8a65..fba16288a57247e5278a2e90afac18988c71c3e5 100644 (file)
--- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py
+++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
@@ -32,6 +32,9 @@ class _ZxJDBCDate(sqltypes.Date):
 class _ZxJDBCNumeric(sqltypes.Numeric):
 
     def result_processor(self, dialect, coltype):
+        #XXX: does the dialect return Decimal or not???
+        # if it does (in all cases), we could use a None processor as well as
+        # the to_float generic processor
         if self.asdecimal:
             def process(value):
                 if isinstance(value, decimal.Decimal):
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index e90bebb6b290954c1d163e7b02f87bcef22baddd..079b05530c6c41b1561162c42f910942240efbf6 100644 (file)
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -19,31 +19,23 @@ Interval
 Passing data from/to the Interval type is not supported as of yet.
 
 """
-from sqlalchemy.engine import default
 import decimal
+
+from sqlalchemy.engine import default
 from sqlalchemy import util, exc
+from sqlalchemy import processors
 from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql.base import PGDialect, \
                 PGCompiler, PGIdentifierPreparer, PGExecutionContext
 
 class _PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
-        def process(value):
-            if value is not None:
-                return float(value)
-            else:
-                return value
-        return process
+        return processors.to_float
     
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             if coltype in (700, 701):
-                def process(value):
-                    if value is not None:
-                        return decimal.Decimal(str(value))
-                    else:
-                        return value
-                return process
+                return processors.to_decimal_processor_factory(decimal.Decimal)
             elif coltype == 1700:
                 # pg8000 returns Decimal natively for 1700
                 return None
@@ -54,12 +46,7 @@ class _PGNumeric(sqltypes.Numeric):
                 # pg8000 returns float natively for 701
                 return None
             elif coltype == 1700:
-                def process(value):
-                    if value is not None:
-                        return float(value)
-                    else:
-                        return value
-                return process
+                return processors.to_float
             else:
                 raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
 
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index bb6562deafdf485350929754c2c3d8f5e7556087..712124288a21e66d5370718ce6ec361092c084f6 100644 (file)
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -46,8 +46,11 @@ The following per-statement execution options are respected:
 
 """
 
-import decimal, random, re
+import random, re
+import decimal
+
 from sqlalchemy import util
+from sqlalchemy import processors
 from sqlalchemy.engine import base, default
 from sqlalchemy.sql import expression
 from sqlalchemy.sql import operators as sql_operators
@@ -63,12 +66,7 @@ class _PGNumeric(sqltypes.Numeric):
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             if coltype in (700, 701):
-                def process(value):
-                    if value is not None:
-                        return decimal.Decimal(str(value))
-                    else:
-                        return value
-                return process
+                return processors.to_decimal_processor_factory(decimal.Decimal)
             elif coltype == 1700:
                 # pg8000 returns Decimal natively for 1700
                 return None
@@ -79,12 +77,7 @@ class _PGNumeric(sqltypes.Numeric):
                 # pg8000 returns float natively for 701
                 return None
             elif coltype == 1700:
-                def process(value):
-                    if value is not None:
-                        return float(value)
-                    else:
-                        return value
-                return process
+                return processors.to_float
             else:
                 raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
 
diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
index 77ed44512bfc9db8e90931d36aba7234c0869fb4..88f1acde721154b47b9188480392fc6e8db76a79 100644 (file)
--- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
+++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
@@ -12,6 +12,7 @@ import decimal
 from sqlalchemy import util
 from sqlalchemy import types as sqltypes
 from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
+from sqlalchemy import processors
 
 class PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
@@ -21,12 +22,7 @@ class PGNumeric(sqltypes.Numeric):
         if self.asdecimal:
             return None
         else:
-            def process(value):
-                if value is not None:
-                    return float(value)
-                else:
-                    return value
-            return process
+            return processors.to_float
 
 class PostgreSQL_pypostgresqlExecutionContext(PGExecutionContext):
     pass
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 696f65a6ca77ddc8f91a4af4d4d36e9579c20671..e987439c5901e3e11cec9fdea2f01fdf307d0cc6 100644 (file)
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -54,6 +54,7 @@ from sqlalchemy import types as sqltypes
 from sqlalchemy import util
 from sqlalchemy.sql import compiler, functions as sql_functions
 from sqlalchemy.util import NoneType
+from sqlalchemy import processors
 
 from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\
                             FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\
@@ -62,13 +63,10 @@ from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\
 
 class _NumericMixin(object):
     def bind_processor(self, dialect):
-        type_ = self.asdecimal and str or float
-        def process(value):
-            if value is not None:
-                return type_(value)
-            else:
-                return value
-        return process
+        if self.asdecimal:
+            return processors.to_str
+        else:
+            return processors.to_float
 
 class _SLNumeric(_NumericMixin, sqltypes.Numeric):
     pass
@@ -86,19 +84,7 @@ class _DateTimeMixin(object):
         if storage_format is not None:
             self._storage_format = storage_format
             
-    def _result_processor(self, fn):
-        rmatch = self._reg.match
-        # Even on python2.6 datetime.strptime is both slower than this code
-        # and it does not support microseconds.
-        def process(value):
-            if value is not None:
-                return fn(*map(int, rmatch(value).groups(0)))
-            else:
-                return None
-        return process
-
 class DATETIME(_DateTimeMixin, sqltypes.DateTime):
-    _reg = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
     _storage_format = "%04d-%02d-%02d %02d:%02d:%02d.%06d"
   
     def bind_processor(self, dialect):
@@ -121,10 +107,13 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
         return process
 
     def result_processor(self, dialect, coltype):
-        return self._result_processor(datetime.datetime)
+        if self._reg:
+            return processors.str_to_datetime_processor_factory(
+                self._reg, datetime.datetime)
+        else:
+            return processors.str_to_datetime
 
 class DATE(_DateTimeMixin, sqltypes.Date):
-    _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
     _storage_format = "%04d-%02d-%02d"
 
     def bind_processor(self, dialect):
@@ -141,10 +130,13 @@ class DATE(_DateTimeMixin, sqltypes.Date):
         return process
   
     def result_processor(self, dialect, coltype):
-        return self._result_processor(datetime.date)
+        if self._reg:
+            return processors.str_to_datetime_processor_factory(
+                self._reg, datetime.date)
+        else:
+            return processors.str_to_date
 
 class TIME(_DateTimeMixin, sqltypes.Time):
-    _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
     _storage_format = "%02d:%02d:%02d.%06d"
 
     def bind_processor(self, dialect):
@@ -162,7 +154,11 @@ class TIME(_DateTimeMixin, sqltypes.Time):
         return process
   
     def result_processor(self, dialect, coltype):
-        return self._result_processor(datetime.time)
+        if self._reg:
+            return processors.str_to_datetime_processor_factory(
+                self._reg, datetime.time)
+        else:
+            return processors.str_to_time
 
 colspecs = {
     sqltypes.Date: DATE,
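The three result_processor implementations above now share one code path: a dialect-supplied regexp (self._reg) goes through str_to_datetime_processor_factory, and otherwise the precompiled module-level defaults from the new processors module are used. As a quick check that the default processor round-trips the storage format written by bind_processor:

    import datetime
    from sqlalchemy import processors

    # _storage_format above is "%04d-%02d-%02d %02d:%02d:%02d.%06d"
    assert processors.str_to_datetime("2010-02-13 22:53:39.000000") == \
        datetime.datetime(2010, 2, 13, 22, 53, 39)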
index cdbf6138daee332aeb7ca6303bb7e4e5255e87bd..886a773d8e09f2d1767e5d16c331cbd05cfc1272 100644 (file)
@@ -115,24 +115,7 @@ class SybaseUniqueIdentifier(sqltypes.TypeEngine):
     __visit_name__ = "UNIQUEIDENTIFIER"
     
 class SybaseBoolean(sqltypes.Boolean):
-    def result_processor(self, dialect, coltype):
-        def process(value):
-            if value is None:
-                return None
-            return value and True or False
-        return process
-
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is True:
-                return 1
-            elif value is False:
-                return 0
-            elif value is None:
-                return None
-            else:
-                return value and True or False
-        return process
+    pass
 
 class SybaseTypeCompiler(compiler.GenericTypeCompiler):
     def visit_large_binary(self, type_):
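With its custom processors deleted, SybaseBoolean now inherits the generic Boolean behavior (see the types.py hunk later in this diff), which falls back to processors.int_to_boolean on dialects without native boolean support. A sketch of that conversion, mirroring the pure-Python fallback:

    def int_to_boolean(value):
        if value is None:
            return None
        return value and True or False

    assert int_to_boolean(0) is False
    assert int_to_boolean(5) is True
    assert int_to_boolean(None) is None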
index 844183628d062385b534836083f048cefca3433a..4dc9665c0c841d0adb8821fe7527dc6963e232a9 100644 (file)
@@ -20,6 +20,7 @@ __all__ = [
     'connection_memoize']
 
 import inspect, StringIO, sys, operator
+from itertools import izip
 from sqlalchemy import exc, schema, util, types, log
 from sqlalchemy.sql import expression
 
@@ -1536,16 +1537,20 @@ class Engine(Connectable):
 def _proxy_connection_cls(cls, proxy):
     class ProxyConnection(cls):
         def execute(self, object, *multiparams, **params):
-            return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params)
+            return proxy.execute(self, super(ProxyConnection, self).execute, 
+                                            object, *multiparams, **params)
 
         def _execute_clauseelement(self, elem, multiparams=None, params=None):
-            return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {}))
+            return proxy.execute(self, super(ProxyConnection, self).execute, 
+                                            elem, *(multiparams or []), **(params or {}))
 
         def _cursor_execute(self, cursor, statement, parameters, context=None):
-            return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False)
+            return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, 
+                                            cursor, statement, parameters, context, False)
 
         def _cursor_executemany(self, cursor, statement, parameters, context=None):
-            return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True)
+            return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, 
+                                            cursor, statement, parameters, context, True)
 
         def _begin_impl(self):
             return proxy.begin(self, super(ProxyConnection, self)._begin_impl)
@@ -1560,27 +1565,125 @@ def _proxy_connection_cls(cls, proxy):
             return proxy.savepoint(self, super(ProxyConnection, self)._savepoint_impl, name=name)
 
         def _rollback_to_savepoint_impl(self, name, context):
-            return proxy.rollback_savepoint(self, super(ProxyConnection, self)._rollback_to_savepoint_impl, name, context)
+            return proxy.rollback_savepoint(self, 
+                                    super(ProxyConnection, self)._rollback_to_savepoint_impl, 
+                                    name, context)
             
         def _release_savepoint_impl(self, name, context):
-            return proxy.release_savepoint(self, super(ProxyConnection, self)._release_savepoint_impl, name, context)
+            return proxy.release_savepoint(self, 
+                                    super(ProxyConnection, self)._release_savepoint_impl, 
+                                    name, context)
 
         def _begin_twophase_impl(self, xid):
-            return proxy.begin_twophase(self, super(ProxyConnection, self)._begin_twophase_impl, xid)
+            return proxy.begin_twophase(self, 
+                                    super(ProxyConnection, self)._begin_twophase_impl, xid)
 
         def _prepare_twophase_impl(self, xid):
-            return proxy.prepare_twophase(self, super(ProxyConnection, self)._prepare_twophase_impl, xid)
+            return proxy.prepare_twophase(self, 
+                                    super(ProxyConnection, self)._prepare_twophase_impl, xid)
 
         def _rollback_twophase_impl(self, xid, is_prepared):
-            return proxy.rollback_twophase(self, super(ProxyConnection, self)._rollback_twophase_impl, xid, is_prepared)
+            return proxy.rollback_twophase(self, 
+                                    super(ProxyConnection, self)._rollback_twophase_impl, 
+                                    xid, is_prepared)
 
         def _commit_twophase_impl(self, xid, is_prepared):
-            return proxy.commit_twophase(self, super(ProxyConnection, self)._commit_twophase_impl, xid, is_prepared)
+            return proxy.commit_twophase(self, 
+                                    super(ProxyConnection, self)._commit_twophase_impl, 
+                                    xid, is_prepared)
 
     return ProxyConnection
 
+# This reconstructor is necessary so that pickles with the C extension or
+# without use the same Binary format.
+# We need a different reconstructor on the C extension so that we can
+# add extra checks that fields have correctly been initialized by
+# __setstate__.
+try:
+    from sqlalchemy.cresultproxy import rowproxy_reconstructor
+
+    # this is a hack so that the reconstructor function is pickled under the
+    # same name whether or not the C extension is in use.
+    # BUG: the assignment fails with
+    #   TypeError: 'builtin_function_or_method' object has only read-only attributes (assign to .__module__)
+    # when "import sqlalchemy" is run from a plain interpreter, yet it
+    # succeeds when the tests are run under nosetests (verified with pdb).
+    #rowproxy_reconstructor.__module__ = 'sqlalchemy.engine.base'
+
+except ImportError:
+    def rowproxy_reconstructor(cls, state):
+        obj = cls.__new__(cls)
+        obj.__setstate__(state)
+        return obj
+
+try:
+    from sqlalchemy.cresultproxy import BaseRowProxy
+except ImportError:
+    class BaseRowProxy(object):
+        __slots__ = ('_parent', '_row', '_processors', '_keymap')
+
+        def __init__(self, parent, row, processors, keymap):
+            """RowProxy objects are constructed by ResultProxy objects."""
+
+            self._parent = parent
+            self._row = row
+            self._processors = processors
+            self._keymap = keymap
+
+        def __reduce__(self):
+            return (rowproxy_reconstructor,
+                    (self.__class__, self.__getstate__()))
+
+        def values(self):
+            """Return the values represented by this RowProxy as a list."""
+            return list(self)
+
+        def __iter__(self):
+            for processor, value in izip(self._processors, self._row):
+                if processor is None:
+                    yield value
+                else:
+                    yield processor(value)
+
+        def __len__(self):
+            return len(self._row)
 
-class RowProxy(object):
+        def __getitem__(self, key):
+            try:
+                processor, index = self._keymap[key]
+            except KeyError:
+                processor, index = self._parent._key_fallback(key)
+            except TypeError:
+                if isinstance(key, slice):
+                    l = []
+                    for processor, value in izip(self._processors[key],
+                                                 self._row[key]):
+                        if processor is None:
+                            l.append(value)
+                        else:
+                            l.append(processor(value))
+                    return tuple(l)
+                else:
+                    raise
+            if index is None:
+                raise exc.InvalidRequestError(
+                        "Ambiguous column name '%s' in result set! "
+                        "try 'use_labels' option on select statement." % key)
+            if processor is not None:
+                return processor(self._row[index])
+            else:
+                return self._row[index]
+
+        def __getattr__(self, name):
+            try:
+                # TODO: no test coverage here
+                return self[name]
+            except KeyError, e:
+                raise AttributeError(e.args[0])
+
+
+class RowProxy(BaseRowProxy):
     """Proxy values from a single cursor row.
 
     Mostly follows "ordered dictionary" behavior, mapping result
@@ -1589,38 +1692,22 @@ class RowProxy(object):
     mapped to the original Columns that produced this result set (for
     results that correspond to constructed SQL expressions).
     """
+    __slots__ = ()
 
-    __slots__ = ['__parent', '__row', '__colfuncs']
-
-    def __init__(self, parent, row):
-
-        self.__parent = parent
-        self.__row = row
-        self.__colfuncs = parent._colfuncs
-        if self.__parent._echo:
-            self.__parent.logger.debug("Row %r", row)
-        
     def __contains__(self, key):
-        return self.__parent._has_key(self.__row, key)
+        return self._parent._has_key(self._row, key)
 
-    def __len__(self):
-        return len(self.__row)
-    
     def __getstate__(self):
         return {
-            '__row':[self.__colfuncs[i][0](self.__row) for i in xrange(len(self.__row))],
-            '__parent':self.__parent
+            '_parent': self._parent,
+            '_row': tuple(self)
         }
-    
-    def __setstate__(self, d):
-        self.__row = d['__row']
-        self.__parent = d['__parent']
-        self.__colfuncs = self.__parent._colfuncs
-        
-    def __iter__(self):
-        row = self.__row 
-        for func in self.__parent._colfunc_list: 
-            yield func(row)
+
+    def __setstate__(self, state):
+        self._parent = parent = state['_parent']
+        self._row = state['_row']
+        self._processors = parent._processors
+        self._keymap = parent._keymap
 
     __hash__ = None
 
@@ -1636,33 +1723,7 @@ class RowProxy(object):
     def has_key(self, key):
         """Return True if this RowProxy contains the given key."""
 
-        return self.__parent._has_key(self.__row, key)
-
-    def __getitem__(self, key):
-        # the fallback and slices are only useful for __getitem__ anyway 
-        try: 
-            return self.__colfuncs[key][0](self.__row) 
-        except KeyError: 
-            k = self.__parent._key_fallback(key)
-            if k is None:
-                raise exc.NoSuchColumnError(
-                    "Could not locate column in row for column '%s'" % key)
-            else:
-                # save on KeyError + _key_fallback() lookup next time around
-                self.__colfuncs[key] = k
-                return k[0](self.__row)
-        except TypeError: 
-            if isinstance(key, slice): 
-                return tuple(func(self.__row) for func in self.__parent._colfunc_list[key]) 
-            else: 
-                raise
-
-    def __getattr__(self, name):
-        try:
-            # TODO: no test coverage here
-            return self[name]
-        except KeyError, e:
-            raise AttributeError(e.args[0])
+        return self._parent._has_key(self._row, key)
 
     def items(self):
         """Return a list of tuples, each tuple containing a key/value pair."""
@@ -1672,24 +1733,25 @@ class RowProxy(object):
     def keys(self):
         """Return the list of keys as strings represented by this RowProxy."""
 
-        return self.__parent.keys
+        return self._parent.keys
 
     def iterkeys(self):
-        return iter(self.__parent.keys)
-
-    def values(self):
-        """Return the values represented by this RowProxy as a list."""
-
-        return list(self)
+        return iter(self._parent.keys)
 
     def itervalues(self):
         return iter(self)
 
+
 class ResultMetaData(object):
     """Handle cursor.description, applying additional info from an execution context."""
     
     def __init__(self, parent, metadata):
-        self._colfuncs = colfuncs = {}
+        self._processors = processors = []
+
+        # We do not strictly need to store the processor in the key mapping,
+        # but doing so is faster in the Python version (probably because it
+        # saves the attribute lookup on self._processors).
+        self._keymap = keymap = {}
         self.keys = []
         self._echo = parent._echo
         context = parent.context
@@ -1720,29 +1782,25 @@ class ResultMetaData(object):
             processor = type_.dialect_impl(dialect).\
                             result_processor(dialect, coltype)
             
-            if processor:
-                def make_colfunc(processor, index):
-                    def getcol(row):
-                        return processor(row[index])
-                    return getcol
-                rec = (make_colfunc(processor, i), i, "colfunc")
-            else:
-                rec = (operator.itemgetter(i), i, "itemgetter")
+            processors.append(processor)
+            rec = (processor, i)
 
-            # indexes as keys
-            colfuncs[i] = rec
+            # indexes as keys. This is only needed for the Python version of
+            # RowProxy (the C version uses a faster path for integer indexes).
+            keymap[i] = rec
             
             # Column names as keys 
-            if colfuncs.setdefault(name.lower(), rec) is not rec: 
-                #XXX: why not raise directly? because several columns colliding 
-                #by name is not a problem as long as the user don't use them (ie 
-                #use the more precise ColumnElement 
-                colfuncs[name.lower()] = (self._ambiguous_processor(name), i, "ambiguous")
-            
+            if keymap.setdefault(name.lower(), rec) is not rec: 
+                # We do not raise an exception directly, because columns
+                # colliding by name are not a problem as long as the user
+                # never accesses them by that name (an integer index or the
+                # more precise ColumnElement still works).
+                keymap[name.lower()] = (processor, None)
+
             # store the "origname" if we truncated (sqlite only)
             if origname and \
-                    colfuncs.setdefault(origname.lower(), rec) is not rec:
-                colfuncs[origname.lower()] = (self._ambiguous_processor(origname), i, "ambiguous")
+                    keymap.setdefault(origname.lower(), rec) is not rec:
+                keymap[origname.lower()] = (processor, None)
             
             if dialect.requires_name_normalize:
                 colname = dialect.normalize_name(colname)
@@ -1750,76 +1808,67 @@ class ResultMetaData(object):
             self.keys.append(colname)
             if obj:
                 for o in obj:
-                    colfuncs[o] = rec
+                    keymap[o] = rec
 
         if self._echo:
             self.logger = context.engine.logger
             self.logger.debug(
                 "Col %r", tuple(x[0] for x in metadata))
 
-    @util.memoized_property
-    def _colfunc_list(self):
-        funcs = self._colfuncs
-        return [funcs[i][0] for i in xrange(len(self.keys))]
-
     def _key_fallback(self, key):
-        funcs = self._colfuncs
-
+        map = self._keymap
+        result = None
         if isinstance(key, basestring):
-            key = key.lower()
-            if key in funcs:
-                return funcs[key]
-
+            result = map.get(key.lower())
         # fallback for targeting a ColumnElement to a textual expression
         # this is a rare use case which only occurs when matching text()
-        # constructs to ColumnElements
-        if isinstance(key, expression.ColumnElement):
-            if key._label and key._label.lower() in funcs:
-                return funcs[key._label.lower()]
-            elif hasattr(key, 'name') and key.name.lower() in funcs:
-                return funcs[key.name.lower()]
-
-        return None
+        # constructs to ColumnElements, and after a pickle/unpickle roundtrip
+        elif isinstance(key, expression.ColumnElement):
+            if key._label and key._label.lower() in map:
+                result = map[key._label.lower()]
+            elif hasattr(key, 'name') and key.name.lower() in map:
+                result = map[key.name.lower()]
+        if result is None:
+            raise exc.NoSuchColumnError(
+                "Could not locate column in row for column '%s'" % key)
+        else:
+            map[key] = result
+        return result
 
     def _has_key(self, row, key):
-        if key in self._colfuncs:
+        if key in self._keymap:
             return True
         else:
-            key = self._key_fallback(key)
-            return key is not None
+            try:
+                self._key_fallback(key)
+                return True
+            except exc.NoSuchColumnError:
+                return False
 
-    @classmethod
-    def _ambiguous_processor(cls, colname):
-        def process(value):
-            raise exc.InvalidRequestError(
-                    "Ambiguous column name '%s' in result set! "
-                    "try 'use_labels' option on select statement." % colname)
-        return process
-    
     def __len__(self):
         return len(self.keys)
 
     def __getstate__(self):
         return {
-            '_pickled_colfuncs':dict(
-                (key, (i, type_)) 
-                for key, (fn, i, type_) in self._colfuncs.iteritems() 
+            '_pickled_keymap': dict(
+                (key, index)
+                for key, (processor, index) in self._keymap.iteritems()
                 if isinstance(key, (basestring, int))
             ),
-            'keys':self.keys
+            'keys': self.keys
         }
     
     def __setstate__(self, state):
-        pickled_colfuncs = state['_pickled_colfuncs']
-        self._colfuncs = d = {}
-        for key, (index, type_) in pickled_colfuncs.iteritems():
-            if type_ == 'ambiguous':
-                d[key] = (self._ambiguous_processor(key), index, type_)
-            else:
-                d[key] = (operator.itemgetter(index), index, "itemgetter")
+        # the row was already processed at pickling time, so no processors
+        # are needed anymore
+        self._processors = [None for _ in xrange(len(state['keys']))]
+        self._keymap = keymap = {}
+        for key, index in state['_pickled_keymap'].iteritems():
+            keymap[key] = (None, index)
         self.keys = state['keys']
         self._echo = False
-        
+
+
 class ResultProxy(object):
     """Wraps a DB-API cursor object to provide easier access to row columns.
 
@@ -2031,13 +2080,27 @@ class ResultProxy(object):
     def _fetchall_impl(self):
         return self.cursor.fetchall()
 
+    def process_rows(self, rows):
+        process_row = self._process_row
+        metadata = self._metadata
+        keymap = metadata._keymap
+        processors = metadata._processors
+        if self._echo:
+            log = self.context.engine.logger.debug
+            l = []
+            for row in rows:
+                log("Row %r", row)
+                l.append(process_row(metadata, row, processors, keymap))
+            return l
+        else:
+            return [process_row(metadata, row, processors, keymap)
+                    for row in rows]
+
     def fetchall(self):
         """Fetch all rows, just like DB-API ``cursor.fetchall()``."""
 
         try:
-            process_row = self._process_row
-            metadata = self._metadata
-            l = [process_row(metadata, row) for row in self._fetchall_impl()]
+            l = self.process_rows(self._fetchall_impl())
             self.close()
             return l
         except Exception, e:
@@ -2053,9 +2116,7 @@ class ResultProxy(object):
         """
 
         try:
-            process_row = self._process_row
-            metadata = self._metadata
-            l = [process_row(metadata, row) for row in self._fetchmany_impl(size)]
+            l = self.process_rows(self._fetchmany_impl(size))
             if len(l) == 0:
                 self.close()
             return l
@@ -2074,7 +2135,7 @@ class ResultProxy(object):
         try:
             row = self._fetchone_impl()
             if row is not None:
-                return self._process_row(self._metadata, row)
+                return self.process_rows([row])[0]
             else:
                 self.close()
                 return None
@@ -2096,13 +2157,12 @@ class ResultProxy(object):
 
         try:
             if row is not None:
-                return self._process_row(self._metadata, row)
+                return self.process_rows([row])[0]
             else:
                 return None
         finally:
             self.close()
         
-        
     def scalar(self):
         """Fetch the first column of the first row, and close the result set.
         
@@ -2210,9 +2270,18 @@ class FullyBufferedResultProxy(ResultProxy):
         return ret
 
 class BufferedColumnRow(RowProxy):
-    def __init__(self, parent, row):
-        row = [parent._orig_colfuncs[i][0](row) for i in xrange(len(row))]
-        super(BufferedColumnRow, self).__init__(parent, row)
+    def __init__(self, parent, row, processors, keymap):
+        # preprocess row
+        row = list(row)
+        # this is a tad faster than using enumerate
+        index = 0
+        for processor in parent._orig_processors:
+            if processor is not None:
+                row[index] = processor(row[index])
+            index += 1
+        row = tuple(row)
+        super(BufferedColumnRow, self).__init__(parent, row,
+                                                processors, keymap)
         
 class BufferedColumnResultProxy(ResultProxy):
     """A ResultProxy with column buffering behavior.
@@ -2221,7 +2290,7 @@ class BufferedColumnResultProxy(ResultProxy):
     fetchone() is called.  If fetchmany() or fetchall() are called,
     the full grid of results is fetched.  This is to operate with
     databases where result rows contain "live" results that fall out
-    of scope unless explicitly fetched.  Currently this includes 
+    of scope unless explicitly fetched.  Currently this includes
     cx_Oracle LOB objects.
     
     """
@@ -2230,17 +2299,16 @@ class BufferedColumnResultProxy(ResultProxy):
 
     def _init_metadata(self):
         super(BufferedColumnResultProxy, self)._init_metadata()
-        self._metadata._orig_colfuncs = self._metadata._colfuncs
-        self._metadata._colfuncs = colfuncs = {}
-        # replace the parent's _colfuncs dict, replacing 
-        # column processors with straight itemgetters.
-        # the original _colfuncs dict is used when each row
-        # is constructed.
-        for k, (colfunc, index, type_) in self._metadata._orig_colfuncs.iteritems():
-            if type_ == "colfunc":
-                colfuncs[k] = (operator.itemgetter(index), index, "itemgetter")
-            else:
-                colfuncs[k] = (colfunc, index, type_)
+        metadata = self._metadata
+        # orig_processors will be used to preprocess each row when it is
+        # constructed.
+        metadata._orig_processors = metadata._processors
+        # replace all of the type processors with None.
+        metadata._processors = [None for _ in xrange(len(metadata.keys))]
+        keymap = {}
+        for k, (func, index) in metadata._keymap.iteritems():
+            keymap[k] = (None, index)
+        self._metadata._keymap = keymap
 
     def fetchall(self):
         # can't call cursor.fetchall(), since rows must be
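The net result of the refactor above: ResultMetaData keeps two parallel structures, _processors (a list with one callable-or-None entry per column) and _keymap (a dict mapping integer indexes, lower-cased names and Column objects to (processor, index) records, where index=None flags an ambiguous name). A toy, self-contained illustration of the lookup contract (hypothetical names, not the actual classes):

    # two columns: 'id' (no processor) and 'amount' (float processor)
    processors = [None, float]
    keymap = {
        0: (None, 0), 'id': (None, 0),
        1: (float, 1), 'amount': (float, 1),
        'dup': (None, None),            # ambiguous name: index is None
    }
    row = ('7', '1.5')

    def getitem(key):
        processor, index = keymap[key]
        if index is None:
            raise KeyError("ambiguous column name %r" % key)
        if processor is not None:
            return processor(row[index])
        return row[index]

    assert getitem('id') == '7'
    assert getitem('amount') == 1.5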
diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py
new file mode 100644 (file)
index 0000000..cb4b725
--- /dev/null
@@ -0,0 +1,90 @@
+# processors.py
+# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""defines generic type conversion functions, as used in result processors.
+
+They all share one common characteristic: None is passed through unchanged.
+
+"""
+
+import codecs
+import re
+import datetime
+
+def str_to_datetime_processor_factory(regexp, type_):
+    rmatch = regexp.match
+    # Even on Python 2.6, datetime.strptime is slower than this code and
+    # does not support microseconds.
+    def process(value):
+        if value is None:
+            return None
+        else:
+            return type_(*map(int, rmatch(value).groups(0)))
+    return process
+
+try:
+    from sqlalchemy.cprocessors import UnicodeResultProcessor, \
+                                       DecimalResultProcessor, \
+                                       to_float, to_str, int_to_boolean, \
+                                       str_to_datetime, str_to_time, \
+                                       str_to_date
+
+    def to_unicode_processor_factory(encoding):
+        return UnicodeResultProcessor(encoding).process
+
+    def to_decimal_processor_factory(target_class):
+        return DecimalResultProcessor(target_class).process
+
+except ImportError:
+    def to_unicode_processor_factory(encoding):
+        decoder = codecs.getdecoder(encoding)
+
+        def process(value):
+            if value is None:
+                return None
+            else:
+                # decoder returns a tuple: (value, len). Simply dropping the
+                # len part is safe: it is done that way in the normal
+                # 'xx'.decode(encoding) code path.
+                # cf. python-source/Python/codecs.c:PyCodec_Decode
+                return decoder(value)[0]
+        return process
+
+    def to_decimal_processor_factory(target_class):
+        def process(value):
+            if value is None:
+                return None
+            else:
+                return target_class(str(value))
+        return process
+
+    def to_float(value):
+        if value is None:
+            return None
+        else:
+            return float(value)
+
+    def to_str(value):
+        if value is None:
+            return None
+        else:
+            return str(value)
+
+    def int_to_boolean(value):
+        if value is None:
+            return None
+        else:
+            return value and True or False
+
+    DATETIME_RE = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
+    TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+    DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")
+
+    str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE,
+                                                        datetime.datetime)
+    str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
+    str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)
+
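A short usage sketch of the new module (exercising the pure-Python fallbacks above; the C versions from cprocessors are drop-in replacements):

    from sqlalchemy import processors

    to_unicode = processors.to_unicode_processor_factory('utf-8')
    assert to_unicode('caf\xc3\xa9') == u'caf\xe9'

    # the common contract: None always passes through unchanged
    for fn in (processors.to_float, processors.to_str,
               processors.int_to_boolean, processors.str_to_date,
               to_unicode):
        assert fn(None) is None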
index 8cab6ceba1525686d333be32865b895372230622..c5256affa48c517ba76a7fefbdf984420ce979f8 100644 (file)
@@ -93,9 +93,16 @@ def function_call_count(count=None, versions={}, variance=0.05):
 
     version_info = list(sys.version_info)
     py_version = '.'.join([str(v) for v in sys.version_info])
-
+    try:
+        from sqlalchemy.cprocessors import to_float
+        cextension = True
+    except ImportError:
+        cextension = False
+        
     while version_info:
         version = '.'.join([str(v) for v in version_info])
+        if cextension:
+            version += "+cextension"
         if version in versions:
             count = versions[version]
             break
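With this change a versions dict can carry per-variant call counts, '+cextension' being appended to every version key when the C extension imports. Assuming the (elided) tail of the loop pops the least-significant version component each iteration, the keys tried on Python 2.6.5 with the extension built would be '2.6.5+cextension', '2.6+cextension', then '2+cextension'; a standalone rendering of that lookup:

    versions = {'2.4': 13214, '2.6+cextension': 409}
    version_info = [2, 6, 5]
    cextension = True
    count = None
    while version_info:
        version = '.'.join([str(v) for v in version_info])
        if cextension:
            version += "+cextension"
        if version in versions:
            count = versions[version]
            break
        version_info.pop()
    assert count == 409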
index 465454df95bc2d60d22580b2fbf6463ba1d4abaa..36302cae3e6f90d61ed0bb01757e828e77875515 100644 (file)
@@ -32,6 +32,7 @@ schema.types = expression.sqltypes =sys.modules['sqlalchemy.types']
 from sqlalchemy.util import pickle
 from sqlalchemy.sql.visitors import Visitable
 from sqlalchemy import util
+from sqlalchemy import processors
 
 NoneType = type(None)
 if util.jython:
@@ -608,14 +609,7 @@ class String(Concatenable, TypeEngine):
         if needs_convert:
             # note we *assume* that we do not have a unicode object
             # here, instead of an expensive isinstance() check.
-            decoder = codecs.getdecoder(dialect.encoding)
-            def process(value):
-                if value is not None:
-                    # decoder returns a tuple: (value, len)
-                    return decoder(value)[0]
-                else:
-                    return value
-            return process
+            return processors.to_unicode_processor_factory(dialect.encoding)
         else:
             return None
 
@@ -810,21 +804,15 @@ class Numeric(_DateAffinity, TypeEngine):
         return dbapi.NUMBER
 
     def bind_processor(self, dialect):
-        def process(value):
-            if value is not None:
-                return float(value)
-            else:
-                return value
-        return process
+        return processors.to_float
 
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            def process(value):
-                if value is not None:
-                    return _python_Decimal(str(value))
-                else:
-                    return value
-            return process
+            #XXX: use decimal from http://www.bytereef.org/libmpdec.html
+            #try:
+            #    from fastdec import mpd as Decimal
+            #except ImportError:
+            return processors.to_decimal_processor_factory(_python_Decimal)
         else:
             return None
 
@@ -991,11 +979,7 @@ class _Binary(TypeEngine):
                 else:
                     return None
         else:
-            def process(value):
-                if value is not None:
-                    return str(value)
-                else:
-                    return None
+            process = processors.to_str
         return process
     # end Py2K
     
@@ -1349,11 +1333,7 @@ class Boolean(TypeEngine, SchemaType):
         if dialect.supports_native_boolean:
             return None
         else:
-            def process(value):
-                if value is None:
-                    return None
-                return value and True or False
-            return process
+            return processors.int_to_boolean
 
 class Interval(_DateAffinity, TypeDecorator):
     """A type for ``datetime.timedelta()`` objects.
@@ -1419,7 +1399,7 @@ class Interval(_DateAffinity, TypeDecorator):
         if impl_processor:
             def process(value):
                 value = impl_processor(value)
-                if value is None: 
+                if value is None:
                     return None
                 return value - epoch
         else:
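The decimal branch above now delegates to to_decimal_processor_factory; note the factory converts through str() first, which avoids binary-float representation noise. A quick sketch with the stdlib Decimal as target_class:

    from decimal import Decimal
    from sqlalchemy import processors

    to_decimal = processors.to_decimal_processor_factory(Decimal)
    assert to_decimal(1.1) == Decimal("1.1")   # i.e. Decimal(str(1.1))
    assert to_decimal(None) is None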
index 3a1b5f1dd386821ba9434dbc7d256b4b9182879e..20da456d9bc29246dce42c985d4142c2f8fe19a8 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -15,9 +15,11 @@ if sys.version_info >= (3, 0):
     )
 
 try:
-    from setuptools import setup
+    from setuptools import setup, Extension
 except ImportError:
-    from distutils.core import setup
+    from distutils.core import setup, Extension
+
+BUILD_CEXTENSIONS = False
 
 def find_packages(dir_):
     packages = []
@@ -46,6 +48,12 @@ setup(name = "SQLAlchemy",
       license = "MIT License",
       tests_require = ['nose >= 0.10'],
       test_suite = "nose.collector",
+      ext_modules = (BUILD_CEXTENSIONS and
+                     [Extension('sqlalchemy.cprocessors',
+                                sources=['lib/sqlalchemy/cextension/processors.c']),
+                      Extension('sqlalchemy.cresultproxy',
+                                sources=['lib/sqlalchemy/cextension/resultproxy.c'])
+                                ]),
       entry_points = {
           'nose.plugins.0.10': [
               'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy',
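A note on the ext_modules expression above: written as a boolean and-expression, it yields False (a falsy value distutils treats as "no extensions") while BUILD_CEXTENSIONS keeps its default, and yields the Extension list once the flag is flipped. The idiom in isolation:

    BUILD_CEXTENSIONS = False
    assert (BUILD_CEXTENSIONS and ['ext']) is False    # nothing is compiled
    BUILD_CEXTENSIONS = True
    assert (BUILD_CEXTENSIONS and ['ext']) == ['ext']  # list handed to setup()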
index 83901b7f7a848bfe9e2e07c6331f3f9cc48cfbed..459a8e4c4dd2ed543d97211f35f34fb52b8cb34d 100644 (file)
@@ -29,13 +29,13 @@ class ResultSetTest(TestBase, AssertsExecutionResults):
     def teardown(self):
         metadata.drop_all()
         
-    @profiling.function_call_count(14416, versions={'2.4':13214})
+    @profiling.function_call_count(14416, versions={'2.4':13214, '2.6+cextension':409})
     def test_string(self):
         [tuple(row) for row in t.select().execute().fetchall()]
 
     # sqlite3 returns native unicode.  so shouldn't be an
     # increase here.
-    @profiling.function_call_count(14396, versions={'2.4':13214})
+    @profiling.function_call_count(14396, versions={'2.4':13214, '2.6+cextension':409})
     def test_unicode(self):
         [tuple(row) for row in t2.select().execute().fetchall()]
 
index 66bb45f31279614f33cec2e874fd77b1cf694764..706f8e470b7ec323fd5f07d1b505fb3fe83de5ad 100644 (file)
@@ -339,7 +339,7 @@ class ZooMarkTest(TestBase):
     def test_profile_3_properties(self):
         self.test_baseline_3_properties()
 
-    @profiling.function_call_count(13341, {'2.4': 7963})
+    @profiling.function_call_count(13341, {'2.4': 7963, '2.6+cextension':12447})
     def test_profile_4_expressions(self):
         self.test_baseline_4_expressions()
 
@@ -351,7 +351,7 @@ class ZooMarkTest(TestBase):
     def test_profile_6_editing(self):
         self.test_baseline_6_editing()
 
-    @profiling.function_call_count(2641, {'2.4': 1673})
+    @profiling.function_call_count(2641, {'2.4': 1673, '2.6+cextension':2502})
     def test_profile_7_multiview(self):
         self.test_baseline_7_multiview()
 
diff --git a/test/perf/stress_all.py b/test/perf/stress_all.py
new file mode 100644 (file)
index 0000000..ad074ee
--- /dev/null
@@ -0,0 +1,226 @@
+# -*- encoding: utf8 -*-
+from datetime import *
+from decimal import Decimal
+#from fastdec import mpd as Decimal
+from cPickle import dumps, loads
+
+#from sqlalchemy.dialects.postgresql.base import ARRAY
+
+from stresstest import *
+
+# ---
+test_types = False
+test_methods = True
+test_pickle = False
+test_orm = False
+# ---
+verbose = True
+
+def values_results(raw_results):
+    return [tuple(r.values()) for r in raw_results]
+
+def getitem_str_results(raw_results):
+    return [
+        (r['id'],
+         r['field0'], r['field1'], r['field2'], r['field3'], r['field4'],
+         r['field5'], r['field6'], r['field7'], r['field8'], r['field9'])
+         for r in raw_results]
+
+def getitem_fallback_results(raw_results):
+    return [
+        (r['ID'],
+         r['FIELD0'], r['FIELD1'], r['FIELD2'], r['FIELD3'], r['FIELD4'],
+         r['FIELD5'], r['FIELD6'], r['FIELD7'], r['FIELD8'], r['FIELD9'])
+         for r in raw_results]
+
+def getitem_int_results(raw_results):
+    return [
+        (r[0],
+         r[1], r[2], r[3], r[4], r[5],
+         r[6], r[7], r[8], r[9], r[10])
+         for r in raw_results]
+
+def getitem_long_results(raw_results):
+    return [
+        (r[0L],
+         r[1L], r[2L], r[3L], r[4L], r[5L],
+         r[6L], r[7L], r[8L], r[9L], r[10L])
+         for r in raw_results]
+
+def getitem_obj_results(raw_results):
+    c = test_table.c
+    fid, f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = (
+        c.id, c.field0, c.field1, c.field2, c.field3, c.field4,
+        c.field5, c.field6, c.field7, c.field8, c.field9)
+    return [
+        (r[fid],
+         r[f0], r[f1], r[f2], r[f3], r[f4],
+         r[f5], r[f6], r[f7], r[f8], r[f9])
+         for r in raw_results]
+
+def slice_results(raw_results):
+    return [row[0:6] + row[6:11] for row in raw_results]
+
+# ---------- #
+# Test types #
+# ---------- #
+
+# Array
+#def genarrayvalue(rnum, fnum):
+#    return [fnum, fnum + 1, fnum + 2]
+#arraytest = (ARRAY(Integer), genarrayvalue,
+#             dict(num_fields=100, num_records=1000,
+#                  engineurl='postgresql:///test'))
+
+# Boolean
+def genbooleanvalue(rnum, fnum):
+    if rnum % 4:
+        return bool(fnum % 2)
+    else:
+        return None
+booleantest = (Boolean, genbooleanvalue, dict(num_records=100000))
+
+# Datetime
+def gendatetimevalue(rnum, fnum):
+    return (rnum % 4) and datetime(2005, 3, 3) or None
+datetimetest = (DateTime, gendatetimevalue, dict(num_records=10000))
+
+# Decimal
+def gendecimalvalue(rnum, fnum):
+    if rnum % 4:
+        return Decimal(str(0.25 * fnum))
+    else:
+        return None
+decimaltest = (Numeric(10, 2), gendecimalvalue, dict(num_records=10000))
+
+# Interval
+
+# no microseconds, because PostgreSQL does not seem to support them
+from_epoch = timedelta(14643, 70235)
+def genintervalvalue(rnum, fnum):
+    return from_epoch
+intervaltest = (Interval, genintervalvalue,
+                dict(num_fields=2, num_records=100000))
+
+# PickleType
+def genpicklevalue(rnum, fnum):
+    return (rnum % 4) and {'str': "value%d" % fnum, 'int': rnum} or None
+pickletypetest = (PickleType, genpicklevalue,
+                  dict(num_fields=1, num_records=100000))
+
+# TypeDecorator
+class MyIntType(TypeDecorator):
+    impl = Integer
+
+    def process_bind_param(self, value, dialect):
+        return value * 10
+
+    def process_result_value(self, value, dialect):
+        return value / 10
+
+    def copy(self):
+        return MyIntType()
+
+def genmyintvalue(rnum, fnum):
+    return rnum + fnum
+typedecoratortest = (MyIntType, genmyintvalue,
+                     dict(num_records=100000))
+
+# Unicode
+def genunicodevalue(rnum, fnum):
+    return (rnum % 4) and (u"value%d" % fnum) or None
+unicodetest = (Unicode(20, assert_unicode=False), genunicodevalue,
+               dict(num_records=100000))
+#               dict(engineurl='mysql:///test', freshdata=False))
+
+# do the tests
+if test_types:
+    tests = [booleantest, datetimetest, decimaltest, intervaltest,
+             pickletypetest, typedecoratortest, unicodetest]
+    for engineurl in ('postgresql://scott:tiger@localhost/test', 
+                        'sqlite://', 'mysql://scott:tiger@localhost/test'):
+        print "\n%s\n" % engineurl
+        for datatype, genvalue, kwargs in tests:
+            print "%s:" % getattr(datatype, '__name__',
+                                  datatype.__class__.__name__),
+            profile_and_time_dbfunc(iter_results, datatype, genvalue,
+                                    profile=False, engineurl=engineurl,
+                                    verbose=verbose, **kwargs)
+
+# ---------------------- #
+# test row proxy methods #
+# ---------------------- #
+
+if test_methods:
+    methods = [iter_results, values_results, getattr_results,
+               getitem_str_results, getitem_fallback_results,
+               getitem_int_results, getitem_long_results, getitem_obj_results,
+               slice_results]
+    for engineurl in ('postgresql://scott:tiger@localhost/test', 
+                       'sqlite://', 'mysql://scott:tiger@localhost/test'):
+        print "\n%s\n" % engineurl
+        test_table = prepare(Unicode(20, assert_unicode=False),
+                             genunicodevalue,
+                             num_fields=10, num_records=100000,
+                             verbose=verbose, engineurl=engineurl)
+        for method in methods:
+            print "%s:" % method.__name__,
+            time_dbfunc(test_table, method, genunicodevalue,
+                        num_fields=10, num_records=100000, profile=False,
+                        verbose=verbose)
+
+# --------------------------------
+# test pickling RowProxy instances
+# --------------------------------
+
+def pickletofile_results(raw_results):
+    from cPickle import dump, load
+    for protocol in (0, 1, 2):
+        print "dumping protocol %d..." % protocol
+        f = file('noext.pickle%d' % protocol, 'wb')
+        dump(raw_results, f, protocol)
+        f.close()
+    return raw_results
+
+def pickle_results(raw_results):
+    return loads(dumps(raw_results, 2))
+
+def pickle_meta(raw_results):
+    pickled = dumps(raw_results[0]._parent, 2)
+    metadata = loads(pickled)
+    return raw_results
+
+def pickle_rows(raw_results):
+    return [loads(dumps(row, 2)) for row in raw_results]
+
+if test_pickle:
+    test_table = prepare(Unicode, genunicodevalue,
+                         num_fields=10, num_records=10000)
+    funcs = [pickle_rows, pickle_results]
+    for func in funcs:
+        print "%s:" % func.__name__,
+        time_dbfunc(test_table, func, genunicodevalue,
+                    num_records=10000, profile=False, verbose=verbose)
+
+# --------------------------------
+# test ORM
+# --------------------------------
+
+if test_orm:
+    from sqlalchemy.orm import *
+
+    class Test(object):
+        pass
+
+    Session = sessionmaker()
+    session = Session()
+
+    def get_results():
+        return session.query(Test).all()
+    print "ORM:",
+    for engineurl in ('postgresql:///test', 'sqlite://', 'mysql:///test'):
+        print "\n%s\n" % engineurl
+        profile_and_time_dbfunc(getattr_results, Unicode(20), genunicodevalue,
+                                class_=Test, getresults_func=get_results,
+                                engineurl=engineurl, #freshdata=False,
+                                num_records=10000, verbose=verbose)
diff --git a/test/perf/stresstest.py b/test/perf/stresstest.py
new file mode 100644 (file)
index 0000000..cf9404f
--- /dev/null
@@ -0,0 +1,174 @@
+import gc
+import sys
+import timeit
+import cProfile
+
+from sqlalchemy import MetaData, Table, Column
+from sqlalchemy.types import *
+from sqlalchemy.orm import mapper, clear_mappers
+
+metadata = MetaData()
+
+def gen_table(num_fields, field_type, metadata):
+    return Table('test', metadata,
+        Column('id', Integer, primary_key=True),
+        *[Column("field%d" % fnum, field_type)
+          for fnum in range(num_fields)])
+
+def insert(test_table, num_fields, num_records, genvalue, verbose=True):
+    if verbose:
+        print "building insert values...",
+        sys.stdout.flush()
+    values = [dict(("field%d" % fnum, genvalue(rnum, fnum))
+                   for fnum in range(num_fields))
+              for rnum in range(num_records)]
+    if verbose:
+        print "inserting...",
+        sys.stdout.flush()
+    def db_insert():
+        test_table.insert().execute(values)
+    sys.modules['__main__'].db_insert = db_insert
+    timing = timeit.timeit("db_insert()",
+                            "from __main__ import db_insert",
+                            number=1)
+    if verbose:
+        print "%s" % round(timing, 3)
+
+def check_result(results, num_fields, genvalue, verbose=True):
+    if verbose:
+        print "checking...",
+        sys.stdout.flush()
+    for rnum, row in enumerate(results):
+        expected = tuple([rnum + 1] +
+                         [genvalue(rnum, fnum) for fnum in range(num_fields)])
+        assert row == expected, "got: %s\nexpected: %s" % (row, expected)
+    return True
+
+def avgdev(values, comparison):
+    return sum(value - comparison for value in values) / len(values)
+
+def nicer_res(values, printvalues=False):
+    if printvalues:
+        print values
+    min_time = min(values)
+    return round(min_time, 3), round(avgdev(values, min_time), 2)
+
+def profile_func(func_name, verbose=True):
+    if verbose:
+        print "profiling...",
+        sys.stdout.flush()
+    cProfile.run('%s()' % func_name, 'prof')
+
+def time_func(func_name, num_tests=1, verbose=True):
+    if verbose:
+        print "timing...",
+        sys.stdout.flush()
+    timings = timeit.repeat('%s()' % func_name,
+                            "from __main__ import %s" % func_name,
+                            number=num_tests, repeat=5)
+    avg, dev = nicer_res(timings)
+    if verbose:
+        print "%s (%s)" % (avg, dev)
+    else:
+        print avg
+
+def profile_and_time(func_name, num_tests=1):
+    profile_func(func_name)
+    time_func(func_name, num_tests)
+
+def iter_results(raw_results):
+    return [tuple(row) for row in raw_results]
+
+def getattr_results(raw_results):
+    return [
+        (r.id,
+         r.field0, r.field1, r.field2, r.field3, r.field4,
+         r.field5, r.field6, r.field7, r.field8, r.field9)
+         for r in raw_results]
+
+def fetchall(test_table):
+    def results():
+        return test_table.select().order_by(test_table.c.id).execute() \
+                         .fetchall()
+    return results
+
+def hashable_set(l):
+    hashables = []
+    for o in l:
+        try:
+            hash(o)
+            hashables.append(o)
+        except TypeError:
+            # unhashable objects are simply skipped
+            pass
+    return set(hashables)
+
+def prepare(field_type, genvalue, engineurl='sqlite://',
+            num_fields=10, num_records=1000, freshdata=True, verbose=True):
+    global metadata
+    metadata.clear()
+    metadata.bind = engineurl
+    test_table = gen_table(num_fields, field_type, metadata)
+    if freshdata:
+        metadata.drop_all()
+        metadata.create_all()
+        insert(test_table, num_fields, num_records, genvalue, verbose)
+    return test_table
+
+def time_dbfunc(test_table, test_func, genvalue,
+                class_=None,
+                getresults_func=None,
+                num_fields=10, num_records=1000, num_tests=1,
+                check_results=check_result, profile=True,
+                check_leaks=True, print_leaks=False, verbose=True):
+    if verbose:
+        print "testing '%s'..." % test_func.__name__,
+    sys.stdout.flush()
+    if class_ is not None:
+        clear_mappers()
+        mapper(class_, test_table)
+    if getresults_func is None:
+        getresults_func = fetchall(test_table)
+    def test():
+        return test_func(getresults_func())
+    sys.modules['__main__'].test = test
+    if check_leaks:
+        gc.collect()
+        objects_before = gc.get_objects()
+        num_objects_before = len(objects_before)
+        hashable_objects_before = hashable_set(objects_before)
+#    gc.set_debug(gc.DEBUG_LEAK)
+    if check_results:
+        check_results(test(), num_fields, genvalue, verbose)
+    if check_leaks:
+        gc.collect()
+        objects_after = gc.get_objects()
+        num_objects_after = len(objects_after)
+        num_leaks = num_objects_after - num_objects_before
+        hashable_objects_after = hashable_set(objects_after)
+        diff = hashable_objects_after - hashable_objects_before
+        ldiff = len(diff)
+        if print_leaks and ldiff < num_records:
+            print "\n*** hashable objects leaked (%d) ***" % ldiff
+            print '\n'.join(map(str, diff))
+            print "***\n"
+
+        if num_leaks > num_records:
+            print "(leaked: %d !)" % num_leaks,
+    if profile:
+        profile_func('test', verbose)
+    time_func('test', num_tests, verbose)
+
+def profile_and_time_dbfunc(test_func, field_type, genvalue,
+                            class_=None,
+                            getresults_func=None,
+                            engineurl='sqlite://', freshdata=True,
+                            num_fields=10, num_records=1000, num_tests=1,
+                            check_results=check_result, profile=True,
+                            check_leaks=True, print_leaks=False, verbose=True):
+    test_table = prepare(field_type, genvalue, engineurl,
+                         num_fields, num_records, freshdata, verbose)
+    time_dbfunc(test_table, test_func, genvalue, class_,
+                getresults_func,
+                num_fields, num_records, num_tests,
+                check_results, profile,
+                check_leaks, print_leaks, verbose)
index 345ecef67a3e40414931e855667f295e14da432a..5433cb92fc23567ec0bb93ee212866314b90093c 100644 (file)
@@ -701,21 +701,21 @@ class QueryTest(TestBase):
                          Column('shadow_name', VARCHAR(20)),
                          Column('parent', VARCHAR(20)),
                          Column('row', VARCHAR(40)),
-                         Column('__parent', VARCHAR(20)),
-                         Column('__row', VARCHAR(20)),
+                         Column('_parent', VARCHAR(20)),
+                         Column('_row', VARCHAR(20)),
         )
         shadowed.create(checkfirst=True)
         try:
-            shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
+            shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', _parent='Hidden parent', _row='Hidden row')
             r = shadowed.select(shadowed.c.shadow_id==1).execute().first()
             self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1)
             self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow')
             self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light')
             self.assert_(r.row == r['row'] == r[shadowed.c.row] == 'Without light there is no shadow')
-            self.assert_(r['__parent'] == 'Hidden parent')
-            self.assert_(r['__row'] == 'Hidden row')
+            self.assert_(r['_parent'] == 'Hidden parent')
+            self.assert_(r['_row'] == 'Hidden row')
             try:
-                print r.__parent, r.__row
+                print r._parent, r._row
                 self.fail('Should not allow access to private attributes')
             except AttributeError:
                 pass # expected
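For context on this rename: the old RowProxy kept its state in double-underscore attributes, which Python name-mangles (e.g. to _RowProxy__parent), so the test shadowed columns literally named __parent; the rewritten class stores plain _parent/_row slots, and the test renames its shadow columns accordingly. A tiny illustration of the mangling this relies on (hypothetical standalone code):

    class Old(object):
        def __init__(self):
            self.__parent = 'internal'      # stored as _Old__parent

    o = Old()
    assert not hasattr(o, '__parent')       # the unmangled name does not exist
    assert o._Old__parent == 'internal'     # the mangled name does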