]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed 1.0 regression where value objects that override
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Jul 2015 17:19:28 +0000 (13:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Jul 2015 17:19:28 +0000 (13:19 -0400)
``__eq__()`` to return a non-boolean-capable object, such as
some geoalchemy types as well as numpy types, were being tested
for ``bool()`` during a unit of work update operation, where in
0.9 the return value of ``__eq__()`` was tested against "is True"
to guard against this.
fixes #3469

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/orm/persistence.py
test/orm/test_unitofworkv2.py

index 9f0f0dff3ef629f04b7fd383b2e7226f3695c6c1..8ac3d5844b0c84ab00e0cc207813a54cb5b9ce60 100644 (file)
 .. changelog::
     :version: 1.0.7
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3469
+
+        Fixed 1.0 regression where value objects that override
+        ``__eq__()`` to return a non-boolean-capable object, such as
+        some geoalchemy types as well as numpy types, were being tested
+        for ``bool()`` during a unit of work update operation, where in
+        0.9 the return value of ``__eq__()`` was tested against "is True"
+        to guard against this.
+
     .. change::
         :tags: bug, orm
         :tickets: 3468
index 4f074df8e69cdbb436ee605e0e66b572f88c1743..0bfee2ecec2dc07da1b9801efde6be6eca128e2f 100644 (file)
@@ -455,8 +455,10 @@ def _collect_update_commands(
 
                 if isinstance(value, sql.ClauseElement):
                     value_params[col] = value
-                elif not state.manager[propkey].impl.is_equal(
-                        value, state.committed_state[propkey]):
+                # guard against values that generate non-__nonzero__
+                # objects for __eq__()
+                elif state.manager[propkey].impl.is_equal(
+                        value, state.committed_state[propkey]) is not True:
                     params[col.key] = value
 
         if update_version_id is not None and \
index 42b774b102f0dd51ab7ed077101eb25d6a0839ae..35d32a6b34dcd9e4532fa210c5bfdfd85af4e538 100644 (file)
@@ -1846,3 +1846,111 @@ class NoAttrEventInFlushTest(fixtures.MappedTest):
         eq_(t1.id, 1)
         eq_(t1.prefetch_val, 5)
         eq_(t1.returning_val, 5)
+
+
+class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
+    """test support for custom datatypes that return a non-__bool__ value
+    when compared via __eq__(), eg. ticket 3469"""
+
+    @classmethod
+    def define_tables(cls, metadata):
+        from sqlalchemy import TypeDecorator
+
+        class NoBool(object):
+            def __nonzero__(self):
+                raise NotImplementedError("not supported")
+
+        class MyWidget(object):
+            def __init__(self, text):
+                self.text = text
+
+            def __eq__(self, other):
+                return NoBool()
+
+        cls.MyWidget = MyWidget
+
+        class MyType(TypeDecorator):
+            impl = String(50)
+
+            def process_bind_param(self, value, dialect):
+                if value is not None:
+                    value = value.text
+                return value
+
+            def process_result_value(self, value, dialect):
+                if value is not None:
+                    value = MyWidget(value)
+                return value
+
+        Table(
+            'test', metadata,
+            Column('id', Integer, primary_key=True,
+                   test_needs_autoincrement=True),
+            Column('value', MyType),
+            Column('unrelated', String(10))
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Thing(cls.Basic):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        Thing = cls.classes.Thing
+
+        mapper(Thing, cls.tables.test)
+
+    def test_update_against_none(self):
+        Thing = self.classes.Thing
+
+        s = Session()
+        s.add(Thing(value=self.MyWidget("foo")))
+        s.commit()
+
+        t1 = s.query(Thing).first()
+        t1.value = None
+        s.commit()
+
+        eq_(
+            s.query(Thing.value).scalar(), None
+        )
+
+    def test_update_against_something_else(self):
+        Thing = self.classes.Thing
+
+        s = Session()
+        s.add(Thing(value=self.MyWidget("foo")))
+        s.commit()
+
+        t1 = s.query(Thing).first()
+        t1.value = self.MyWidget("bar")
+        s.commit()
+
+        eq_(
+            s.query(Thing.value).scalar().text, "bar"
+        )
+
+    def test_no_update_no_change(self):
+        Thing = self.classes.Thing
+
+        s = Session()
+        s.add(Thing(value=self.MyWidget("foo"), unrelated='unrelated'))
+        s.commit()
+
+        t1 = s.query(Thing).first()
+        t1.unrelated = 'something else'
+
+        self.assert_sql_execution(
+            testing.db,
+            s.commit,
+            CompiledSQL(
+                "UPDATE test SET unrelated=:unrelated "
+                "WHERE test.id = :test_id",
+                [{'test_id': 1, 'unrelated': 'something else'}]
+            ),
+        )
+
+        eq_(
+            s.query(Thing.value).scalar().text, "foo"
+        )