]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Handle association proxy delete and provide for scalar delete cascade
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Aug 2018 16:01:59 +0000 (12:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Aug 2018 22:06:48 +0000 (18:06 -0400)
Fixed multiple issues regarding de-association of scalar objects with the
association proxy.  ``del`` now works, and additionally a new flag
:paramref:`.AssociationProxy.cascade_scalar_deletes` is added, which when
set to True indicates that setting a scalar attribute to ``None`` or
deleting via ``del`` will also set the source association to ``None``.

Change-Id: I1580d761571d63eb03a7e8df078cef97d265b85c
Fixes: #4308
doc/build/changelog/migration_13.rst
doc/build/changelog/unreleased_13/4308.rst [new file with mode: 0644]
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/attributes.py
test/ext/test_associationproxy.py
test/orm/test_attributes.py

index 5740728a2fe6f43069e420ed1432ccbc8bad6a8e..7d41ef5f6f10d50a0d5af9d60a3302d662da1e86 100644 (file)
@@ -23,6 +23,59 @@ New Features and Improvements - ORM
 Key Behavioral Changes - ORM
 =============================
 
+.. _change_4308:
+
+Association proxy has new cascade_scalar_deletes flag
+-----------------------------------------------------
+
+Given a mapping as::
+
+    class A(Base):
+        __tablename__ = 'test_a'
+        id = Column(Integer, primary_key=True)
+        ab = relationship(
+            'AB', backref='a', uselist=False)
+        b = association_proxy(
+            'ab', 'b', creator=lambda b: AB(b=b),
+            cascade_scalar_deletes=True)
+
+
+    class B(Base):
+        __tablename__ = 'test_b'
+        id = Column(Integer, primary_key=True)
+        ab = relationship('AB', backref='b', cascade='all, delete-orphan')
+
+
+    class AB(Base):
+        __tablename__ = 'test_ab'
+        a_id = Column(Integer, ForeignKey(A.id), primary_key=True)
+        b_id = Column(Integer, ForeignKey(B.id), primary_key=True)
+
+An assigment to ``A.b`` will generate an ``AB`` object::
+
+    a.b = B()
+
+The ``A.b`` association is scalar, and includes a new flag
+:paramref:`.AssociationProxy.cascade_scalar_deletes`.  When set, setting ``A.b``
+to ``None`` will remove ``A.ab`` as well.   The default behavior remains
+that it leaves ``a.ab`` in place::
+
+    a.b = None
+    assert a.ab is None
+
+While it at first seemed intuitive that this logic should just look at the
+"cascade" attribute of the existing relationship, it's not clear from that
+alone if the proxied object should be removed, hence the behavior is
+made available as an explicit option.
+
+Additionally, ``del`` now works for scalars in a similar manner as setting
+to ``None``::
+
+    del a.b
+    assert a.ab is None
+
+:ticket:`4308`
+
 .. _change_4246:
 
 FOR UPDATE clause is rendered within the joined eager load subquery as well as outside
diff --git a/doc/build/changelog/unreleased_13/4308.rst b/doc/build/changelog/unreleased_13/4308.rst
new file mode 100644 (file)
index 0000000..d4d3d75
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, ext
+    :tickets: 4308
+
+       Fixed multiple issues regarding de-association of scalar objects with the
+       association proxy.  ``del`` now works, and additionally a new flag
+       :paramref:`.AssociationProxy.cascade_scalar_deletes` is added, which when
+       set to True indicates that setting a scalar attribute to ``None`` or
+       deleting via ``del`` will also set the source association to ``None``.
+
+    .. seealso::
+
+        :ref:`change_4308`
+
index a0945fa6c98f3a2bc310127e7bac0996cfaff214..3c27cb59f6bfd909e11c449b6da5bb34c1154ce1 100644 (file)
@@ -94,7 +94,8 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
 
     def __init__(self, target_collection, attr, creator=None,
                  getset_factory=None, proxy_factory=None,
-                 proxy_bulk_set=None, info=None):
+                 proxy_bulk_set=None, info=None,
+                 cascade_scalar_deletes=False):
         """Construct a new :class:`.AssociationProxy`.
 
         The :func:`.association_proxy` function is provided as the usual
@@ -119,6 +120,15 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
           If you want to construct instances differently, supply a 'creator'
           function that takes arguments as above and returns instances.
 
+        :param cascade_scalar_deletes: when True, indicates that setting
+         the proxied value to ``None``, or deleting it via ``del``, should
+         also remove the source object.  Only applies to scalar attributes.
+         Normally, removing the proxied target will not remove the proxy
+         source, as this object may have other state that is still to be
+         kept.
+
+         .. versionadded:: 1.3
+
         :param getset_factory: Optional.  Proxied attribute access is
           automatically handled by routines that get and set values based on
           the `attr` argument for this proxy.
@@ -150,6 +160,7 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         self.getset_factory = getset_factory
         self.proxy_factory = proxy_factory
         self.proxy_bulk_set = proxy_bulk_set
+        self.cascade_scalar_deletes = cascade_scalar_deletes
 
         self.owning_class = None
         self.key = '_%s_%s_%s' % (
@@ -308,9 +319,13 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
             creator = self.creator and self.creator or self.target_class
             target = getattr(obj, self.target_collection)
             if target is None:
+                if values is None:
+                    return
                 setattr(obj, self.target_collection, creator(values))
             else:
                 self._scalar_set(target, values)
+                if values is None and self.cascade_scalar_deletes:
+                    setattr(obj, self.target_collection, None)
         else:
             proxy = self.__get__(obj, None)
             if proxy is not values:
@@ -321,7 +336,11 @@ class AssociationProxy(interfaces.InspectionAttrInfo):
         if self.owning_class is None:
             self._calc_owner(obj, None)
 
-        delattr(obj, self.key)
+        if self.scalar:
+            target = getattr(obj, self.target_collection)
+            if target is not None:
+                delattr(target, self.value_attr)
+        delattr(obj, self.target_collection)
 
     def _initialize_scalar_accessors(self):
         if self.getset_factory:
index e9227362e79006277be2c172f4ab927f83494421..0bbe70655ef1faa5dc2bdfe3ceaf54257746fbe0 100644 (file)
@@ -673,7 +673,6 @@ class ScalarAttributeImpl(AttributeImpl):
 
     def delete(self, state, dict_):
 
-        # TODO: catch key errors, convert to attributeerror?
         if self.dispatch._active_history:
             old = self.get(state, dict_, PASSIVE_RETURN_NEVER_SET)
         else:
@@ -682,7 +681,10 @@ class ScalarAttributeImpl(AttributeImpl):
         if self.dispatch.remove:
             self.fire_remove_event(state, dict_, old, self._remove_token)
         state._modified_event(dict_, self, old)
-        del dict_[self.key]
+        try:
+            del dict_[self.key]
+        except KeyError:
+            raise AttributeError("%s object does not have a value" % self)
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
         if self.key in dict_:
@@ -742,7 +744,10 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
     def delete(self, state, dict_):
         old = self.get(state, dict_)
         self.fire_remove_event(state, dict_, old, self._remove_token)
-        del dict_[self.key]
+        try:
+            del dict_[self.key]
+        except:
+            raise AttributeError("%s object does not have a value" % self)
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
         if self.key in dict_:
@@ -969,7 +974,9 @@ class CollectionAttributeImpl(AttributeImpl):
 
         collection = self.get_collection(state, state.dict)
         collection.clear_with_event()
-        # TODO: catch key errors, convert to attributeerror?
+
+        # key is always present because we checked above.  e.g.
+        # del is a no-op if collection not present.
         del dict_[self.key]
 
     def initialize(self, state, dict_):
index 2d403fa8dd64f5d89e70fa496b4b998b43ff4c24..70d5e1613af6c9dc221c923a4218f96b676ad0c1 100644 (file)
@@ -2107,6 +2107,239 @@ class AttributeAccessTest(fixtures.TestBase):
         is_(Bat.foo.owning_class, Bat)
 
 
+class ScalarRemoveTest(object):
+    useobject = None
+    cascade_scalar_deletes = None
+    uselist = None
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(Base):
+            __tablename__ = 'test_a'
+            id = Column(Integer, primary_key=True)
+            ab = relationship(
+                'AB', backref='a',
+                uselist=cls.uselist)
+            b = association_proxy(
+                'ab', 'b', creator=lambda b: AB(b=b),
+                cascade_scalar_deletes=cls.cascade_scalar_deletes)
+
+        if cls.useobject:
+            class B(Base):
+                __tablename__ = 'test_b'
+                id = Column(Integer, primary_key=True)
+                ab = relationship('AB', backref="b")
+
+            class AB(Base):
+                __tablename__ = 'test_ab'
+                a_id = Column(Integer, ForeignKey(A.id), primary_key=True)
+                b_id = Column(Integer, ForeignKey(B.id), primary_key=True)
+
+        else:
+            class AB(Base):
+                __tablename__ = 'test_ab'
+                b = Column(Integer)
+                a_id = Column(Integer, ForeignKey(A.id), primary_key=True)
+
+    def test_set_nonnone_to_none(self):
+        if self.useobject:
+            A, AB, B = self.classes("A", "AB", "B")
+        else:
+            A, AB = self.classes("A", "AB")
+
+        a1 = A()
+
+        b1 = B() if self.useobject else 5
+
+        if self.uselist:
+            a1.b.append(b1)
+        else:
+            a1.b = b1
+
+        if self.uselist:
+            assert isinstance(a1.ab[0], AB)
+        else:
+            assert isinstance(a1.ab, AB)
+
+        if self.uselist:
+            a1.b.remove(b1)
+        else:
+            a1.b = None
+
+        if self.uselist:
+            eq_(a1.ab, [])
+        else:
+            if self.cascade_scalar_deletes:
+                assert a1.ab is None
+            else:
+                assert isinstance(a1.ab, AB)
+                assert a1.ab.b is None
+
+    def test_set_none_to_none(self):
+        if self.uselist:
+            return
+
+        if self.useobject:
+            A, AB, B = self.classes("A", "AB", "B")
+        else:
+            A, AB = self.classes("A", "AB")
+
+        a1 = A()
+
+        a1.b = None
+
+        assert a1.ab is None
+
+    def test_del_already_nonpresent(self):
+        if self.useobject:
+            A, AB, B = self.classes("A", "AB", "B")
+        else:
+            A, AB = self.classes("A", "AB")
+
+        a1 = A()
+
+        if self.uselist:
+            del a1.b
+
+            eq_(a1.ab, [])
+
+        else:
+            def go():
+                del a1.b
+
+            assert_raises_message(
+                AttributeError,
+                "A.ab object does not have a value",
+                go
+            )
+
+    def test_del(self):
+        if self.useobject:
+            A, AB, B = self.classes("A", "AB", "B")
+        else:
+            A, AB = self.classes("A", "AB")
+
+        b1 = B() if self.useobject else 5
+
+        a1 = A()
+        if self.uselist:
+            a1.b.append(b1)
+        else:
+            a1.b = b1
+
+        if self.uselist:
+            assert isinstance(a1.ab[0], AB)
+        else:
+            assert isinstance(a1.ab, AB)
+
+        del a1.b
+
+        if self.uselist:
+            eq_(a1.ab, [])
+        else:
+            assert a1.ab is None
+
+    def test_del_no_proxy(self):
+        if not self.uselist:
+            return
+
+        if self.useobject:
+            A, AB, B = self.classes("A", "AB", "B")
+        else:
+            A, AB = self.classes("A", "AB")
+
+        b1 = B() if self.useobject else 5
+        a1 = A()
+        a1.b.append(b1)
+
+        del a1.ab
+
+        # this is what it does for now, so maintain that w/ assoc proxy
+        eq_(a1.ab, [])
+
+    def test_del_already_nonpresent_no_proxy(self):
+        if not self.uselist:
+            return
+
+        if self.useobject:
+            A, AB, B = self.classes("A", "AB", "B")
+        else:
+            A, AB = self.classes("A", "AB")
+
+        a1 = A()
+
+        del a1.ab
+
+        # this is what it does for now, so maintain that w/ assoc proxy
+        eq_(a1.ab, [])
+
+
+class ScalarRemoveListObjectCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = True
+    cascade_scalar_deletes = True
+    uselist = True
+
+
+class ScalarRemoveScalarObjectCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = True
+    cascade_scalar_deletes = True
+    uselist = False
+
+
+class ScalarRemoveListScalarCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = False
+    cascade_scalar_deletes = True
+    uselist = True
+
+
+class ScalarRemoveScalarScalarCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = False
+    cascade_scalar_deletes = True
+    uselist = False
+
+
+class ScalarRemoveListObjectNoCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = True
+    cascade_scalar_deletes = False
+    uselist = True
+
+
+class ScalarRemoveScalarObjectNoCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = True
+    cascade_scalar_deletes = False
+    uselist = False
+
+
+class ScalarRemoveListScalarNoCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = False
+    cascade_scalar_deletes = False
+    uselist = True
+
+
+class ScalarRemoveScalarScalarNoCascade(
+        ScalarRemoveTest, fixtures.DeclarativeMappedTest):
+
+    useobject = False
+    cascade_scalar_deletes = False
+    uselist = False
+
+
 class InfoTest(fixtures.TestBase):
     def test_constructor(self):
         assoc = association_proxy('a', 'b', info={'some_assoc': 'some_value'})
index 12c9dddb9c2a1afd6256ef49670dbbd143f763ad..d56d81565c6044ae6aed58f2524662d4606e0273 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy.orm.interfaces import AttributeExtension
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.testing import eq_, ne_, assert_raises, \
-    assert_raises_message, is_true, is_false
+    assert_raises_message, is_true, is_false, is_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.util import gc_collect, all_partial_orderings
 from sqlalchemy.util import jython
@@ -303,6 +303,83 @@ class AttributesTest(fixtures.ORMTest):
             lambda: Foo().bars.append(Bar())
         )
 
+    def test_del_scalar_nonobject(self):
+        class Foo(object):
+            pass
+
+        instrumentation.register_class(Foo)
+        attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+
+        f1 = Foo()
+
+        is_(f1.b, None)
+
+        f1.b = 5
+
+        del f1.b
+        is_(f1.b, None)
+
+        def go():
+            del f1.b
+
+        assert_raises_message(
+            AttributeError,
+            "Foo.b object does not have a value",
+            go
+        )
+
+    def test_del_scalar_object(self):
+        class Foo(object):
+            pass
+
+        class Bar(object):
+            pass
+
+        instrumentation.register_class(Foo)
+        instrumentation.register_class(Bar)
+        attributes.register_attribute(Foo, 'b', uselist=False, useobject=True)
+
+        f1 = Foo()
+
+        is_(f1.b, None)
+
+        f1.b = Bar()
+
+        del f1.b
+        is_(f1.b, None)
+
+        def go():
+            del f1.b
+
+        assert_raises_message(
+            AttributeError,
+            "Foo.b object does not have a value",
+            go
+        )
+
+    def test_del_collection_object(self):
+        class Foo(object):
+            pass
+
+        class Bar(object):
+            pass
+
+        instrumentation.register_class(Foo)
+        instrumentation.register_class(Bar)
+        attributes.register_attribute(Foo, 'b', uselist=True, useobject=True)
+
+        f1 = Foo()
+
+        eq_(f1.b, [])
+
+        f1.b = [Bar()]
+
+        del f1.b
+        eq_(f1.b, [])
+
+        del f1.b
+        eq_(f1.b, [])
+
     def test_deferred(self):
         class Foo(object):
             pass