]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- PickleType now favors == comparison by default,
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Dec 2008 17:27:33 +0000 (17:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Dec 2008 17:27:33 +0000 (17:27 +0000)
if the incoming object (such as a dict) implements
__eq__().  If the object does not implement
__eq__() and mutable=True, a deprecation warning
is raised.

CHANGES
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/orm/mapper.py
test/orm/unitofwork.py
test/pickleable.py
test/sql/testtypes.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index 63ad11cc533321a095b4add34a51c547eaf3dfb8..a9a9edd9912db4d7bdc8074382dbf81dbe077779 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -108,6 +108,12 @@ CHANGES
           mapper since it's not needed.
       
 - sql
+    - PickleType now favors == comparison by default,
+      if the incoming object (such as a dict) implements
+      __eq__().  If the object does not implement 
+      __eq__() and mutable=True, a deprecation warning
+      is raised.
+      
     - Fixed the import weirdness in sqlalchemy.sql
       to not export __names__ [ticket:1215].
     
index 70022b35414f06f75ff05f088fa38d0f1d5fde58..2604f4e8fecd93cb68f7f54d1fb41492583c9dea 100644 (file)
@@ -744,12 +744,20 @@ class PickleType(MutableType, TypeDecorator):
           pickle-compatible ``dumps` and ``loads`` methods.
 
         :param mutable: defaults to True; implements
-          :meth:`AbstractType.is_mutable`.
+          :meth:`AbstractType.is_mutable`.   When ``True``, incoming
+          objects *must* provide an ``__eq__()`` method which 
+          performs the desired deep comparison of members, or the 
+          ``comparator`` argument must be present.  Otherwise,
+          comparisons are done by comparing pickle strings.
+          The pickle form of comparison is a deprecated usage and will
+          raise a warning.
 
         :param comparator: optional. a 2-arg callable predicate used
-          to compare values of this type.  Defaults to equality if
-          *mutable* is False or ``pickler.dumps()`` equality if
-          *mutable* is True.
+          to compare values of this type.  Otherwise, either
+          the == operator is used to compare values, or if mutable==True
+          and the incoming object does not implement __eq__(), the value
+          of pickle.dumps(obj) is compared.  The last option is a deprecated
+          usage and will raise a warning.
 
         """
         self.protocol = protocol
@@ -780,7 +788,8 @@ class PickleType(MutableType, TypeDecorator):
     def compare_values(self, x, y):
         if self.comparator:
             return self.comparator(x, y)
-        elif self.mutable:
+        elif self.mutable and not hasattr(x, '__eq__') and x is not None:
+            util.warn_deprecated("Objects stored with PickleType when mutable=True must implement __eq__() for reliable comparison.")
             return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol)
         else:
             return x == y
index 297dd738f151ec5d9649a900fcfc6e4aa23235e5..8b68fb10868c8e2e7c036d1c1172b37e4a88f643 100644 (file)
@@ -138,6 +138,17 @@ except ImportError:
             getattr(wrapper, attr).update(getattr(wrapped, attr, ()))
         return wrapper
 
+try:
+    from functools import partial
+except:
+    def partial(func, *args, **keywords):
+        def newfunc(*fargs, **fkeywords):
+            newkeywords = keywords.copy()
+            newkeywords.update(fkeywords)
+            return func(*(args + fargs), **newkeywords)
+        return newfunc
+
+
 def accepts_a_list_as_starargs(list_deprecation=None):
     def decorate(fn):
 
index 7e2d3b5958a0f1c7f3f790a234f44eb97e8a3ab1..a2cac1a8e1c471de8b2616f152ad3d8f986df67c 100644 (file)
@@ -2284,44 +2284,6 @@ class MagicNamesTest(_base.MappedTest):
                   reserved: maps.c.state})
 
 
-class ScalarRequirementsTest(_base.MappedTest):
-
-    # TODO: is this needed here?
-    # what does this suite excercise that unitofwork doesn't?
-
-    def define_tables(self, metadata):
-        Table('t1', metadata,
-              Column('id', Integer, primary_key=True,
-                     test_needs_autoincrement=True),
-              Column('data', sa.PickleType()))
-
-    def setup_classes(self):
-        class Foo(_base.ComparableEntity):
-            pass
-
-    @testing.resolve_artifact_names
-    def test_correct_comparison(self):
-        mapper(Foo, t1)
-
-        f1 = Foo(data=pickleable.NotComparable('12345'))
-
-        session = create_session()
-        session.add(f1)
-        session.flush()
-        session.clear()
-
-        f1 = session.query(Foo).get(f1.id)
-        eq_(f1.data.data, '12345')
-
-        f2 = Foo(data=pickleable.BrokenComparable('abc'))
-
-        session.add(f2)
-        session.flush()
-        session.clear()
-
-        f2 = session.query(Foo).get(f2.id)
-        eq_(f2.data.data, 'abc')
-
 
 if __name__ == "__main__":
     testenv.main()
index 627e7cb99b56829f6d9f777c78fb707b7f1f57a4..7bdbe745c31b9b28190f4d911d0863f5b702db93 100644 (file)
@@ -364,6 +364,7 @@ class MutableTypesTest(_base.MappedTest):
              "WHERE mutable_t.id = :mutable_t_id",
              {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})])
 
+    @testing.uses_deprecated()
     @testing.resolve_artifact_names
     def test_nocomparison(self):
         """Changes are detected on MutableTypes lacking an __eq__ method."""
index f6331ca0bec0e30f74fc864ec6328c419e26bb4e..3ffc1e59be945513c2fca9ea7f703e9a2be15426 100644 (file)
@@ -20,7 +20,18 @@ class Bar(object):
     def __str__(self):
         return "Bar(%d, %d)" % (self.x, self.y)
 
+class OldSchool:
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    def __eq__(self, other):
+        return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
 
+class OldSchoolWithoutCompare:    
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    
 class BarWithoutCompare(object):
     def __init__(self, x, y):
         self.x = x
index b3e2b0b57f26f933e024269f55df523540f73898..e66ff6b116a0b91a832ce125be08a09efe2796a0 100644 (file)
@@ -801,16 +801,53 @@ class BooleanTest(TestBase, AssertsExecutionResults):
         print res2
         assert(res2==[(2, False)])
 
-try:
-    from functools import partial
-except:
-    def partial(func, *args, **keywords):
-        def newfunc(*fargs, **fkeywords):
-            newkeywords = keywords.copy()
-            newkeywords.update(fkeywords)
-            return func(*(args + fargs), **newkeywords)
-        return newfunc
+class PickleTest(TestBase):
+    def test_noeq_deprecation(self):
+        p1 = PickleType()
+        
+        self.assertRaises(DeprecationWarning, 
+            p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)
+        )
 
+        self.assertRaises(DeprecationWarning, 
+            p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)
+        )
+        
+        @testing.uses_deprecated()
+        def go():
+            # test actual dumps comparison
+            assert p1.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2))
+            assert p1.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2))
+        go()
+        
+        assert p1.compare_values({1:2, 3:4}, {3:4, 1:2})
+        
+        p2 = PickleType(mutable=False)
+        assert not p2.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2))
+        assert not p2.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2))
+        
+    def test_eq_comparison(self):
+        p1 = PickleType()
+        
+        for obj in (
+            {'1':'2'},
+            pickleable.Bar(5, 6),
+            pickleable.OldSchool(10, 11)
+        ):
+            assert p1.compare_values(p1.copy_value(obj), obj)
+
+        self.assertRaises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo'))
+        
+    def test_nonmutable_comparison(self):
+        p1 = PickleType()
+
+        for obj in (
+            {'1':'2'},
+            pickleable.Bar(5, 6),
+            pickleable.OldSchool(10, 11)
+        ):
+            assert p1.compare_values(p1.copy_value(obj), obj)
+    
 class CallableTest(TestBase):
     def setUpAll(self):
         global meta
@@ -820,7 +857,7 @@ class CallableTest(TestBase):
         meta.drop_all()
 
     def test_callable_as_arg(self):
-        ucode = partial(Unicode, assert_unicode=None)
+        ucode = util.partial(Unicode, assert_unicode=None)
 
         thing_table = Table('thing', meta,
             Column('name', ucode(20))
@@ -829,7 +866,7 @@ class CallableTest(TestBase):
         thing_table.create()
 
     def test_callable_as_kwarg(self):
-        ucode = partial(Unicode, assert_unicode=None)
+        ucode = util.partial(Unicode, assert_unicode=None)
 
         thang_table = Table('thang', meta,
             Column('name', type_=ucode(20), primary_key=True)
index 5f5d323c79a2b6a87765de6191afa773176be5ed..a7ac138491880219d7abbabd8f7e01d1c47fedd3 100644 (file)
@@ -307,13 +307,13 @@ def emits_warning(*messages):
             filters = [dict(action='ignore',
                             category=sa_exc.SAPendingDeprecationWarning)]
             if not messages:
-                filters.append([dict(action='ignore',
-                                     category=sa_exc.SAWarning)])
+                filters.append(dict(action='ignore',
+                                     category=sa_exc.SAWarning))
             else:
-                filters.extend([dict(action='ignore',
+                filters.extend(dict(action='ignore',
                                      message=message,
                                      category=sa_exc.SAWarning)
-                                for message in messages])
+                                for message in messages)
             for f in filters:
                 warnings.filterwarnings(**f)
             try: