From: Mike Bayer Date: Thu, 11 Dec 2008 17:27:33 +0000 (+0000) Subject: - PickleType now favors == comparison by default, X-Git-Tag: rel_0_5_0~118 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f527d3b9afc212f33bf75084fae5664513ca4184;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - PickleType now favors == comparison by default, if the incoming object (such as a dict) implements __eq__(). If the object does not implement __eq__() and mutable=True, a deprecation warning is raised. --- diff --git a/CHANGES b/CHANGES index 63ad11cc53..a9a9edd991 100644 --- a/CHANGES +++ b/CHANGES @@ -108,6 +108,12 @@ CHANGES mapper since it's not needed. - sql + - PickleType now favors == comparison by default, + if the incoming object (such as a dict) implements + __eq__(). If the object does not implement + __eq__() and mutable=True, a deprecation warning + is raised. + - Fixed the import weirdness in sqlalchemy.sql to not export __names__ [ticket:1215]. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 70022b3541..2604f4e8fe 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -744,12 +744,20 @@ class PickleType(MutableType, TypeDecorator): pickle-compatible ``dumps` and ``loads`` methods. :param mutable: defaults to True; implements - :meth:`AbstractType.is_mutable`. + :meth:`AbstractType.is_mutable`. When ``True``, incoming + objects *must* provide an ``__eq__()`` method which + performs the desired deep comparison of members, or the + ``comparator`` argument must be present. Otherwise, + comparisons are done by comparing pickle strings. + The pickle form of comparison is a deprecated usage and will + raise a warning. :param comparator: optional. a 2-arg callable predicate used - to compare values of this type. Defaults to equality if - *mutable* is False or ``pickler.dumps()`` equality if - *mutable* is True. + to compare values of this type. Otherwise, either + the == operator is used to compare values, or if mutable==True + and the incoming object does not implement __eq__(), the value + of pickle.dumps(obj) is compared. The last option is a deprecated + usage and will raise a warning. """ self.protocol = protocol @@ -780,7 +788,8 @@ class PickleType(MutableType, TypeDecorator): def compare_values(self, x, y): if self.comparator: return self.comparator(x, y) - elif self.mutable: + elif self.mutable and not hasattr(x, '__eq__') and x is not None: + util.warn_deprecated("Objects stored with PickleType when mutable=True must implement __eq__() for reliable comparison.") return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol) else: return x == y diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 297dd738f1..8b68fb1086 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -138,6 +138,17 @@ except ImportError: getattr(wrapper, attr).update(getattr(wrapped, attr, ())) return wrapper +try: + from functools import partial +except: + def partial(func, *args, **keywords): + def newfunc(*fargs, **fkeywords): + newkeywords = keywords.copy() + newkeywords.update(fkeywords) + return func(*(args + fargs), **newkeywords) + return newfunc + + def accepts_a_list_as_starargs(list_deprecation=None): def decorate(fn): diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 7e2d3b5958..a2cac1a8e1 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -2284,44 +2284,6 @@ class MagicNamesTest(_base.MappedTest): reserved: maps.c.state}) -class ScalarRequirementsTest(_base.MappedTest): - - # TODO: is this needed here? - # what does this suite excercise that unitofwork doesn't? - - def define_tables(self, metadata): - Table('t1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', sa.PickleType())) - - def setup_classes(self): - class Foo(_base.ComparableEntity): - pass - - @testing.resolve_artifact_names - def test_correct_comparison(self): - mapper(Foo, t1) - - f1 = Foo(data=pickleable.NotComparable('12345')) - - session = create_session() - session.add(f1) - session.flush() - session.clear() - - f1 = session.query(Foo).get(f1.id) - eq_(f1.data.data, '12345') - - f2 = Foo(data=pickleable.BrokenComparable('abc')) - - session.add(f2) - session.flush() - session.clear() - - f2 = session.query(Foo).get(f2.id) - eq_(f2.data.data, 'abc') - if __name__ == "__main__": testenv.main() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 627e7cb99b..7bdbe745c3 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -364,6 +364,7 @@ class MutableTypesTest(_base.MappedTest): "WHERE mutable_t.id = :mutable_t_id", {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) + @testing.uses_deprecated() @testing.resolve_artifact_names def test_nocomparison(self): """Changes are detected on MutableTypes lacking an __eq__ method.""" diff --git a/test/pickleable.py b/test/pickleable.py index f6331ca0be..3ffc1e59be 100644 --- a/test/pickleable.py +++ b/test/pickleable.py @@ -20,7 +20,18 @@ class Bar(object): def __str__(self): return "Bar(%d, %d)" % (self.x, self.y) +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + def __eq__(self, other): + return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + class BarWithoutCompare(object): def __init__(self, x, y): self.x = x diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index b3e2b0b57f..e66ff6b116 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -801,16 +801,53 @@ class BooleanTest(TestBase, AssertsExecutionResults): print res2 assert(res2==[(2, False)]) -try: - from functools import partial -except: - def partial(func, *args, **keywords): - def newfunc(*fargs, **fkeywords): - newkeywords = keywords.copy() - newkeywords.update(fkeywords) - return func(*(args + fargs), **newkeywords) - return newfunc +class PickleTest(TestBase): + def test_noeq_deprecation(self): + p1 = PickleType() + + self.assertRaises(DeprecationWarning, + p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2) + ) + self.assertRaises(DeprecationWarning, + p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2) + ) + + @testing.uses_deprecated() + def go(): + # test actual dumps comparison + assert p1.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)) + assert p1.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)) + go() + + assert p1.compare_values({1:2, 3:4}, {3:4, 1:2}) + + p2 = PickleType(mutable=False) + assert not p2.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)) + assert not p2.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)) + + def test_eq_comparison(self): + p1 = PickleType() + + for obj in ( + {'1':'2'}, + pickleable.Bar(5, 6), + pickleable.OldSchool(10, 11) + ): + assert p1.compare_values(p1.copy_value(obj), obj) + + self.assertRaises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo')) + + def test_nonmutable_comparison(self): + p1 = PickleType() + + for obj in ( + {'1':'2'}, + pickleable.Bar(5, 6), + pickleable.OldSchool(10, 11) + ): + assert p1.compare_values(p1.copy_value(obj), obj) + class CallableTest(TestBase): def setUpAll(self): global meta @@ -820,7 +857,7 @@ class CallableTest(TestBase): meta.drop_all() def test_callable_as_arg(self): - ucode = partial(Unicode, assert_unicode=None) + ucode = util.partial(Unicode, assert_unicode=None) thing_table = Table('thing', meta, Column('name', ucode(20)) @@ -829,7 +866,7 @@ class CallableTest(TestBase): thing_table.create() def test_callable_as_kwarg(self): - ucode = partial(Unicode, assert_unicode=None) + ucode = util.partial(Unicode, assert_unicode=None) thang_table = Table('thang', meta, Column('name', type_=ucode(20), primary_key=True) diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 5f5d323c79..a7ac138491 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -307,13 +307,13 @@ def emits_warning(*messages): filters = [dict(action='ignore', category=sa_exc.SAPendingDeprecationWarning)] if not messages: - filters.append([dict(action='ignore', - category=sa_exc.SAWarning)]) + filters.append(dict(action='ignore', + category=sa_exc.SAWarning)) else: - filters.extend([dict(action='ignore', + filters.extend(dict(action='ignore', message=message, category=sa_exc.SAWarning) - for message in messages]) + for message in messages) for f in filters: warnings.filterwarnings(**f) try: