From: Mike Bayer Date: Thu, 28 Jan 2010 22:47:25 +0000 (+0000) Subject: - make frozendict serializable X-Git-Tag: rel_0_6beta1~15 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=0b185fc84f32c153239fd42a219b5a3a8e56ebda;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - make frozendict serializable - serialize tests use HIGHEST_PROTOCOL --- diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index fd456e385f..90be39eaf2 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -131,9 +131,9 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): unpickler.persistent_load = persistent_load return unpickler -def dumps(obj): +def dumps(obj, protocol=0): buf = byte_buffer() - pickler = Serializer(buf) + pickler = Serializer(buf, protocol) pickler.dump(obj) return buf.getvalue() diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index f8be5c3fd6..9f92354501 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -154,6 +154,9 @@ class frozendict(dict): def __init__(self, *args): pass + def __reduce__(self): + return frozendict, (dict(self), ) + def union(self, d): if not self: return frozendict(d) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 5377daedd1..9321c888ff 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -77,6 +77,26 @@ class OrderedSetTest(TestBase): eq_(o.intersection(iter([3,4, 6])), util.OrderedSet([3, 4])) eq_(o.union(iter([3,4, 6])), util.OrderedSet([2, 3, 4, 5, 6])) +class FrozenDictTest(TestBase): + def test_serialize(self): + + picklers = set() + try: + import cPickle + picklers.add(cPickle) + except ImportError: + pass + import pickle + picklers.add(pickle) + + d = util.frozendict({1:2, 3:4}) + + # yes, this thing needs this much testing + for pickle in picklers: + for protocol in -1, 0, 1, 2: + print pickle.loads(pickle.dumps(d, protocol)) + + class ColumnCollectionTest(TestBase): def test_in(self): cc = sql.ColumnCollection() diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index c400797b0e..7e5a73efe3 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -69,22 +69,22 @@ class SerializeTest(MappedTest): ) def test_tables(self): - assert serializer.loads(serializer.dumps(users), users.metadata, Session) is users + assert serializer.loads(serializer.dumps(users, -1), users.metadata, Session) is users def test_columns(self): - assert serializer.loads(serializer.dumps(users.c.name), users.metadata, Session) is users.c.name + assert serializer.loads(serializer.dumps(users.c.name, -1), users.metadata, Session) is users.c.name def test_mapper(self): user_mapper = class_mapper(User) - assert serializer.loads(serializer.dumps(user_mapper), None, None) is user_mapper + assert serializer.loads(serializer.dumps(user_mapper, -1), None, None) is user_mapper def test_attribute(self): - assert serializer.loads(serializer.dumps(User.name), None, None) is User.name + assert serializer.loads(serializer.dumps(User.name, -1), None, None) is User.name def test_expression(self): expr = select([users]).select_from(users.join(addresses)).limit(5) - re_expr = serializer.loads(serializer.dumps(expr), users.metadata, None) + re_expr = serializer.loads(serializer.dumps(expr, -1), users.metadata, None) eq_( str(expr), str(re_expr) @@ -100,7 +100,7 @@ class SerializeTest(MappedTest): q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses)) eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) def go(): eq_(q2.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) self.assert_sql_count(testing.db, go, 1) @@ -110,12 +110,12 @@ class SerializeTest(MappedTest): u1 = Session.query(User).get(8) q = Session.query(Address).filter(Address.user==u1).order_by(desc(Address.email)) - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) eq_(q2.all(), [Address(email='ed@wood.com'), Address(email='ed@lala.com'), Address(email='ed@bettyboop.com')]) q = Session.query(User).join(User.addresses).filter(Address.email.like('%fred%')) - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) eq_(q2.all(), [User(name='fred')]) eq_(list(q2.values(User.id, User.name)), [(9, u'fred')]) @@ -128,13 +128,13 @@ class SerializeTest(MappedTest): q = Session.query(User, ualias).join((ualias, User.id < ualias.id)).filter(User.id<9).order_by(User.id, ualias.id) eq_(list(q.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]) - q2 = serializer.loads(serializer.dumps(q), users.metadata, Session) + q2 = serializer.loads(serializer.dumps(q, -1), users.metadata, Session) eq_(list(q2.all()), [(u7, u8), (u7, u9), (u7, u10), (u8, u9), (u8, u10)]) def test_any(self): r = User.addresses.any(Address.email=='x') - ser = serializer.dumps(r) + ser = serializer.dumps(r, -1) x = serializer.loads(ser, users.metadata) eq_(str(r), str(x))