author     Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 11 Oct 2009 17:16:53 +0000 (17:16 +0000)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 11 Oct 2009 17:16:53 +0000 (17:16 +0000)

- RowProxy objects are now pickleable, i.e. the object returned
  by result.fetchone(), result.fetchall() etc.
- the "named tuple" objects returned when iterating a
  Query() are now pickleable.

CHANGES
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/util.py
test/orm/test_query.py
test/sql/test_query.py

diff --git a/CHANGES b/CHANGES
index add2ece7c0d11ccab6997abd7f57908b22360085..44bd7e577a0ae7077d60c389013909688667e14e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -22,6 +22,15 @@ CHANGES
 
   - query.update() and query.delete() both default to
     'evaluate' for the synchronize strategy.
+
+  - the 'synchronize' strategy for update() and delete() raises
+    an error on failure. There is no implicit fallback onto
+    "fetch". Failure of evaluation is based on the structure of
+    criteria, so success/failure is deterministic based on code
+    structure.
+
+  - the "named tuple" objects returned when iterating a
+    Query() are now pickleable.
     
   - query.join() has been reworked to provide more consistent 
     behavior and more flexibility (includes [ticket:1537])
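
For context on the 'evaluate' entry above, here is a minimal, self-contained sketch of the behavior it describes (all names are illustrative, not from the commit; in-memory SQLite and a declarative mapping are assumed):

    import pickle
    from sqlalchemy import create_engine, Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add(User(id=1, name='jack'))
    session.flush()

    # 'evaluate' applies the criteria to in-session objects in Python;
    # per the changelog entry, unevaluable criteria now raise an error
    # rather than implicitly falling back to the 'fetch' strategy.
    session.query(User).filter(User.name == 'jack').update(
        {'name': 'ed'}, synchronize_session='evaluate')
    assert session.query(User).first().name == 'ed'
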
@@ -30,12 +39,6 @@ CHANGES
     multiple comma separated entries within the FROM clause.
     Useful when selecting from multiple-homed join() clauses.
     
-  - the 'synchronize' strategy for update() and delete() raises
-    an error on failure. There is no implicit fallback onto
-    "fetch". Failure of evaluation is based on the structure of
-    criteria, so success/failure is deterministic based on code
-    structure.
-
   - the "dont_load=True" flag on Session.merge() is deprecated
     and is now "load=False".
 
@@ -121,6 +124,9 @@ CHANGES
   - added first() method to ResultProxy, returns first row and
     closes result set immediately.
 
+  - RowProxy objects are now pickleable, i.e. the object returned
+    by result.fetchone(), result.fetchall() etc.
+
 - schema
     - deprecated MetaData.connect() and
       ThreadLocalMetaData.connect() have been removed - send
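
A hedged, self-contained sketch of the new RowProxy behavior described in the changelog entry above (illustrative table and data, in-memory SQLite):

    import pickle
    from sqlalchemy import create_engine

    engine = create_engine('sqlite://')
    engine.execute("CREATE TABLE users (user_id INTEGER, user_name VARCHAR)")
    engine.execute("INSERT INTO users VALUES (7, 'jack')")

    row = engine.execute("SELECT user_id, user_name FROM users").fetchone()
    row2 = pickle.loads(pickle.dumps(row))

    # positional, string-key and keys() access all survive the round trip
    assert row2[0] == 7
    assert row2['user_name'] == 'jack'
    assert row2.keys() == ['user_id', 'user_name']
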
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 70ee295db4acd4ef0841cebd05853eb5bba392fe..26e44dd6b04c64463450c4621e6d4b75518bbcf7 100644 (file)
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1497,7 +1497,7 @@ class RowProxy(object):
         self.__row = row
         if self.__parent._echo:
             self.__parent.context.engine.logger.debug("Row %r", row)
-
+        
     def close(self):
         """Close the parent ResultProxy."""
 
@@ -1508,7 +1508,17 @@ class RowProxy(object):
 
     def __len__(self):
         return len(self.__row)
-
+    
+    def __getstate__(self):
+        return {
+            '__row':[self.__parent._get_col(self.__row, i) for i in xrange(len(self.__row))],
+            '__parent':PickledResultProxy(self.__parent)
+        }
+    
+    def __setstate__(self, d):
+        self.__row = d['__row']
+        self.__parent = d['__parent']
+        
     def __iter__(self):
         for i in xrange(len(self.__row)):
             yield self.__parent._get_col(self.__row, i)
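
The hunk above uses the standard __getstate__/__setstate__ pickle hooks: __getstate__ eagerly runs every column through _get_col() (so DBAPI type conversion happens before serialization) and swaps the live ResultProxy for a PickledResultProxy stand-in; __setstate__ restores both attributes on load. A minimal illustration of the hook pair itself, with purely illustrative names:

    import pickle
    from StringIO import StringIO

    class Proxy(object):
        def __init__(self, fileobj):
            self._fileobj = fileobj           # a live handle; not picklable
            self._data = fileobj.read()

        def __getstate__(self):
            # return only the picklable parts
            return {'_data': self._data}

        def __setstate__(self, d):
            self._data = d['_data']
            self._fileobj = None              # detached after unpickling

    p = pickle.loads(pickle.dumps(Proxy(StringIO("hello"))))
    assert p._data == "hello" and p._fileobj is None
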
@@ -1561,7 +1571,52 @@ class RowProxy(object):
     def itervalues(self):
         return iter(self)
 
-
+class PickledResultProxy(object):
+    """a 'mock' ResultProxy used by a RowProxy being pickled."""
+    
+    _echo = False
+    
+    def __init__(self, resultproxy):
+        self._props = dict(
+            (k, resultproxy._props[k][2]) for k in resultproxy._props
+            if isinstance(k, (basestring, int))
+        )
+        self._keys = resultproxy.keys
+
+    def _fallback_key(self, key):
+        if key in self._props:
+            return self._props[key]
+            
+        if isinstance(key, basestring):
+            key = key.lower()
+            if key in self._props:
+                return self._props[key]
+
+        if isinstance(key, expression.ColumnElement):
+            if key._label and key._label.lower() in self._props:
+                return self._props[key._label.lower()]
+            elif hasattr(key, 'name') and key.name.lower() in self._props:
+                return self._props[key.name.lower()]
+        
+        return None
+        
+    def close(self):
+        pass
+        
+    def _has_key(self, row, key):
+        return self._fallback_key(key) is not None
+        
+    def _get_col(self, row, orig_key):
+        key = self._fallback_key(orig_key)
+        if key is None:
+            raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % orig_key)
+        return row[key]
+        
+    @property
+    def keys(self):
+        return self._keys
+        
+        
 class BufferedColumnRow(RowProxy):
     def __init__(self, parent, row):
         row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
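
_fallback_key() above reproduces the ResultProxy lookup order against the stored key map: exact string/integer key, lower-cased string, then a Column object's _label or .name. Because the match is by name rather than by column identity, a same-named column from a different table also resolves after unpickling, as one of the tests below notes. A hedged sketch (illustrative tables, in-memory SQLite):

    import pickle
    from sqlalchemy import MetaData, Table, Column, Integer, String, create_engine

    meta = MetaData()
    users = Table('users', meta,
                  Column('user_id', Integer),
                  Column('user_name', String(30)))
    engine = create_engine('sqlite://')
    meta.create_all(engine)
    engine.execute(users.insert(), user_id=7, user_name='jack')

    row = pickle.loads(pickle.dumps(engine.execute(users.select()).fetchone()))

    # Column objects resolve by label/name...
    assert row[users.c.user_id] == 7
    # ...so a same-named column from another table matches as well
    addresses = Table('addresses', MetaData(), Column('user_id', Integer))
    assert row[addresses.c.user_id] == 7
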
@@ -1639,7 +1694,7 @@ class ResultProxy(object):
         
         """
         return self.cursor.lastrowid
-
+    
     def _cursor_description(self):
         return self.cursor.description
             
@@ -1732,7 +1787,7 @@ class ResultProxy(object):
                 elif hasattr(key, 'name') and key.name.lower() in props:
                     return props[key.name.lower()]
 
-            raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
+            raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % key)
         return fallback
 
     def __ambiguous_processor(self, colname):
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index b1b85cd01cd6fdb6d786b0eda77ab9380532f043..b347e205e897981add754d4dbde4c40ae1818b1a 100644 (file)
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1371,11 +1371,7 @@ class Query(object):
         (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities])
 
         if not single_entity:
-            labels = dict((label, property(itemgetter(i)))
-                          for i, label in enumerate(labels)
-                          if label)
-            rowtuple = type.__new__(type, "RowTuple", (tuple,), labels)
-            rowtuple.keys = labels.keys
+            labels = [l for l in labels if l]
 
         while True:
             context.progress = {}
@@ -1395,7 +1391,7 @@ class Query(object):
             elif single_entity:
                 rows = [process[0](context, row) for row in fetch]
             else:
-                rows = [rowtuple(proc(context, row) for proc in process)
+                rows = [util.NamedTuple(labels, (proc(context, row) for proc in process))
                         for row in fetch]
 
             if filter:
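
The replaced code built a fresh RowTuple class per query via type.__new__; pickle serializes instances by reference to an importable class, so such rows could never be pickled. A small illustration of that failure mode (names illustrative, not from the commit):

    import pickle

    def make_row():
        # a tuple subclass created on the fly, as the old Query code did
        RowTuple = type("RowTuple", (tuple,), {})
        return RowTuple((1, 2))

    try:
        pickle.dumps(make_row())
    except pickle.PicklingError, err:
        # the class has no importable home, so pickling fails
        print "not pickleable:", err
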
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 67990a2028d04b16ecf2fc724a3bcb9ee3e6b984..8f0b5583dd758af8b51d650cde1f6f45e01b5ef4 100644 (file)
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -635,6 +635,23 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None,
             pass
         setattr(into_cls, method, env[method])
 
+class NamedTuple(tuple):
+    """tuple() subclass that adds labeled names.
+    
+    Is also pickleable.
+    
+    """
+
+    def __new__(cls, labels, vals):
+        vals = list(vals)
+        t = tuple.__new__(cls, vals)
+        t.__dict__ = dict(zip(labels, vals))
+        t._labels = labels
+        return t
+
+    def keys(self):
+        return self._labels
+
 
 class OrderedProperties(object):
     """An object that maintains the order in which attributes are set upon it.
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 31547d16db4660b95c428e76ce5b0e3da02069f9..0ec2b998d98fb78ad5ff2b7e0400d7936659171f 100644 (file)
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -2315,32 +2315,44 @@ class MixedEntitiesTest(QueryTest):
 
     def test_tuple_labeling(self):
         sess = create_session()
-        for row in sess.query(User, Address).join(User.addresses).all():
-            eq_(set(row.keys()), set(['User', 'Address']))
-            eq_(row.User, row[0])
-            eq_(row.Address, row[1])
         
-        for row in sess.query(User.name, User.id.label('foobar')):
-            eq_(set(row.keys()), set(['name', 'foobar']))
-            eq_(row.name, row[0])
-            eq_(row.foobar, row[1])
-
-        for row in sess.query(User).values(User.name, User.id.label('foobar')):
-            eq_(set(row.keys()), set(['name', 'foobar']))
-            eq_(row.name, row[0])
-            eq_(row.foobar, row[1])
-
-        oalias = aliased(Order)
-        for row in sess.query(User, oalias).join(User.orders).all():
-            eq_(set(row.keys()), set(['User']))
-            eq_(row.User, row[0])
-
-        oalias = aliased(Order, name='orders')
-        for row in sess.query(User, oalias).join(User.orders).all():
-            eq_(set(row.keys()), set(['User', 'orders']))
-            eq_(row.User, row[0])
-            eq_(row.orders, row[1])
+        for pickled in False, True:
+            for row in sess.query(User, Address).join(User.addresses).all():
+                if pickled:
+                    row = util.pickle.loads(util.pickle.dumps(row))
+                    
+                eq_(set(row.keys()), set(['User', 'Address']))
+                eq_(row.User, row[0])
+                eq_(row.Address, row[1])
+        
+            for row in sess.query(User.name, User.id.label('foobar')):
+                if pickled:
+                    row = util.pickle.loads(util.pickle.dumps(row))
+                eq_(set(row.keys()), set(['name', 'foobar']))
+                eq_(row.name, row[0])
+                eq_(row.foobar, row[1])
+
+            for row in sess.query(User).values(User.name, User.id.label('foobar')):
+                if pickled:
+                    row = util.pickle.loads(util.pickle.dumps(row))
+                eq_(set(row.keys()), set(['name', 'foobar']))
+                eq_(row.name, row[0])
+                eq_(row.foobar, row[1])
 
+            oalias = aliased(Order)
+            for row in sess.query(User, oalias).join(User.orders).all():
+                if pickled:
+                    row = util.pickle.loads(util.pickle.dumps(row))
+                eq_(set(row.keys()), set(['User']))
+                eq_(row.User, row[0])
+
+            oalias = aliased(Order, name='orders')
+            for row in sess.query(User, oalias).join(User.orders).all():
+                if pickled:
+                    row = util.pickle.loads(util.pickle.dumps(row))
+                eq_(set(row.keys()), set(['User', 'orders']))
+                eq_(row.User, row[0])
+                eq_(row.orders, row[1])
 
     def test_column_queries(self):
         sess = create_session()
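
For reference, an end-to-end sketch of what these tests exercise (declarative mapping and in-memory SQLite; all names illustrative, not from the commit):

    import pickle
    from sqlalchemy import create_engine, Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'users'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add(User(id=7, name='jack'))
    session.commit()

    row = session.query(User.name, User.id.label('foobar')).first()
    row = pickle.loads(pickle.dumps(row))
    assert row.keys() == ['name', 'foobar']
    assert (row.name, row.foobar) == ('jack', 7)
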
diff --git a/test/sql/test_query.py b/test/sql/test_query.py
index 3222ff6ef42b2823a0cb707cc330561af941bfc4..470a694fb947a8a206aff6cc429fd1ef7b42dfcb 100644 (file)
--- a/test/sql/test_query.py
+++ b/test/sql/test_query.py
@@ -1,10 +1,10 @@
 from sqlalchemy.test.testing import eq_
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exc, sql
+from sqlalchemy import exc, sql, util
 from sqlalchemy.engine import default
 from sqlalchemy.test import *
-from sqlalchemy.test.testing import eq_, assert_raises_message
+from sqlalchemy.test.testing import eq_, assert_raises_message, assert_raises
 from sqlalchemy.test.schema import Table, Column
 
 class QueryTest(TestBase):
@@ -207,7 +207,7 @@ class QueryTest(TestBase):
         for row in select([sel + 1, sel + 3], bind=users.bind).execute():
             assert row['anon_1'] == 8
             assert row['anon_2'] == 10
-
+    
     @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
     def test_order_by_label(self):
         """test that a label within an ORDER BY works on each backend.
@@ -260,6 +260,47 @@ class QueryTest(TestBase):
         self.assert_(not (rp != equal))
         self.assert_(not (equal != equal))
 
+    def test_pickled_rows(self):
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9, 'user_name':'fred'},
+        )
+
+        for pickle in False, True:
+            for use_labels in False, True:
+                result = users.select(use_labels=use_labels).order_by(users.c.user_id).execute().fetchall()
+            
+                if pickle:
+                    result = util.pickle.loads(util.pickle.dumps(result))
+                
+                eq_(
+                    result, 
+                    [(7, "jack"), (8, "ed"), (9, "fred")]
+                )
+                if use_labels:
+                    eq_(result[0]['query_users_user_id'], 7)
+                    eq_(result[0].keys(), ["query_users_user_id", "query_users_user_name"])
+                else:
+                    eq_(result[0]['user_id'], 7)
+                    eq_(result[0].keys(), ["user_id", "user_name"])
+                    
+                eq_(result[0][0], 7)
+                eq_(result[0][users.c.user_id], 7)
+                eq_(result[0][users.c.user_name], 'jack')
+            
+                if use_labels:
+                    assert_raises(exc.NoSuchColumnError, lambda: result[0][addresses.c.user_id])
+                else:
+                    # test with a different table.  name resolution is 
+                    # causing 'user_id' to match when use_labels wasn't used.
+                    eq_(result[0][addresses.c.user_id], 7)
+            
+                assert_raises(exc.NoSuchColumnError, lambda: result[0]['fake key'])
+                assert_raises(exc.NoSuchColumnError, lambda: result[0][addresses.c.address_id])
+            
+
+
     @testing.requires.boolean_col_expressions
     def test_or_and_as_columns(self):
         true, false = literal(True), literal(False)