]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
NamedTuple is pickleable ! no really with all the protocols too !
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jan 2010 17:11:27 +0000 (17:11 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jan 2010 17:11:27 +0000 (17:11 +0000)
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/util.py
test/orm/test_pickled.py
test/orm/test_query.py

index bd9069e3a230a945e4689c9e6b0ea77d4aff468f..fc83f91959d7d87e235dcf0b8b0db51e5d48321c 100644 (file)
@@ -1375,8 +1375,7 @@ class Query(object):
             elif single_entity:
                 rows = [process[0](row, None) for row in fetch]
             else:
-                rows = [util.NamedTuple(labels,
-                                        [proc(row, None) for proc in process])
+                rows = [util.NamedTuple([proc(row, None) for proc in process], labels)
                         for row in fetch]
 
             if filter:
@@ -1445,7 +1444,7 @@ class Query(object):
                                 attributes.instance_state(newrow[i]), 
                                 attributes.instance_dict(newrow[i]), 
                                 load=load, _recursive={})
-                    result.append(util.NamedTuple(row._labels, newrow))  
+                    result.append(util.NamedTuple(newrow, row._labels))  
             
             return iter(result)
         finally:
index f7d696971b45205ded96329cc00df0db2e0fdbfd..c3ae255894585d88215c6b84e3e3c69aa8fb9a25 100644 (file)
@@ -642,11 +642,12 @@ class NamedTuple(tuple):
     
     """
 
-    def __new__(cls, labels, vals):
+    def __new__(cls, vals, labels=None):
         vals = list(vals)
         t = tuple.__new__(cls, vals)
-        t.__dict__ = dict(itertools.izip(labels, vals))
-        t._labels = labels
+        if labels:
+            t.__dict__ = dict(itertools.izip(labels, vals))
+            t._labels = labels
         return t
 
     def keys(self):
index 1fcd1b8a43031d8023f4a800dc77265ccd0323ca..0285f4d0b2b883d4600f3a6b61ed5b023587f764 100644 (file)
@@ -4,7 +4,7 @@ import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
 from sqlalchemy.test.schema import Table, Column
-from sqlalchemy.orm import mapper, relation, create_session, attributes, interfaces
+from sqlalchemy.orm import mapper, relation, create_session, sessionmaker, attributes, interfaces
 from test.orm import _base, _fixtures
 
 
@@ -130,6 +130,25 @@ class PickleTest(_fixtures.FixtureTest):
         eq_(ad.email_address, 'ed@bar.com')
         eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
 
+    @testing.resolve_artifact_names
+    def test_pickle_protocols(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+
+        sess = sessionmaker()()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        sess.add(u1)
+        sess.commit()
+
+        u1 = sess.query(User).first()
+        u1.addresses
+        for protocol in -1, 0, 1, 2:
+            u2 = pickle.loads(pickle.dumps(u1, protocol))
+            eq_(u1, u2)
+        
     @testing.resolve_artifact_names
     def test_options_with_descriptors(self):
         mapper(User, users, properties={
index bc3b9e26d8fa1fe554eb8f8c16ed859d72c96028..ee9f1853ba0f3c9f131bd4acc8fe41a103f34c20 100644 (file)
@@ -2403,44 +2403,49 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
     def test_tuple_labeling(self):
         sess = create_session()
         
-        for pickled in False, True:
+        # test pickle + all the protocols !
+        for pickled in False, -1, 0, 1, 2:
             for row in sess.query(User, Address).join(User.addresses).all():
-                if pickled:
-                    row = util.pickle.loads(util.pickle.dumps(row))
+                if pickled is not False:
+                    row = util.pickle.loads(util.pickle.dumps(row, pickled))
                     
                 eq_(set(row.keys()), set(['User', 'Address']))
                 eq_(row.User, row[0])
                 eq_(row.Address, row[1])
         
             for row in sess.query(User.name, User.id.label('foobar')):
-                if pickled:
-                    row = util.pickle.loads(util.pickle.dumps(row))
+                if pickled is not False:
+                    row = util.pickle.loads(util.pickle.dumps(row, pickled))
                 eq_(set(row.keys()), set(['name', 'foobar']))
                 eq_(row.name, row[0])
                 eq_(row.foobar, row[1])
 
             for row in sess.query(User).values(User.name, User.id.label('foobar')):
-                if pickled:
-                    row = util.pickle.loads(util.pickle.dumps(row))
+                if pickled is not False:
+                    row = util.pickle.loads(util.pickle.dumps(row, pickled))
                 eq_(set(row.keys()), set(['name', 'foobar']))
                 eq_(row.name, row[0])
                 eq_(row.foobar, row[1])
 
             oalias = aliased(Order)
             for row in sess.query(User, oalias).join(User.orders).all():
-                if pickled:
-                    row = util.pickle.loads(util.pickle.dumps(row))
+                if pickled is not False:
+                    row = util.pickle.loads(util.pickle.dumps(row, pickled))
                 eq_(set(row.keys()), set(['User']))
                 eq_(row.User, row[0])
 
             oalias = aliased(Order, name='orders')
             for row in sess.query(User, oalias).join(User.orders).all():
-                if pickled:
-                    row = util.pickle.loads(util.pickle.dumps(row))
+                if pickled is not False:
+                    row = util.pickle.loads(util.pickle.dumps(row, pickled))
                 eq_(set(row.keys()), set(['User', 'orders']))
                 eq_(row.User, row[0])
                 eq_(row.orders, row[1])
-
+            
+            if pickled is not False:
+                ret = sess.query(User, Address).join(User.addresses).all()
+                util.pickle.loads(util.pickle.dumps(ret, pickled))
+                
     def test_column_queries(self):
         sess = create_session()