]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
narrow down cascades in session some more
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Dec 2010 22:47:40 +0000 (17:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 19 Dec 2010 22:47:40 +0000 (17:47 -0500)
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/session.py

index 9ff503dfa781a5a1eaac124fbb7da9e9a9eef1da..6981919cfb69aa34d90a43a0d4017fae7ae187d8 100644 (file)
@@ -359,7 +359,7 @@ from sqlalchemy import schema, sql, util
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.orm import scoped_session, sessionmaker, mapper, \
                             class_mapper, relationship, session,\
-                            object_session
+                            object_session, attributes
 from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE
 from sqlalchemy.exceptions import SQLAlchemyError, InvalidRequestError, ArgumentError
 from sqlalchemy.sql import expression
@@ -384,7 +384,8 @@ class AutoAdd(MapperExtension):
 
     def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
         session = self.scoped_session()
-        session._save_without_cascade(instance)
+        state = attributes.instance_state(instance)
+        session._save_impl(state)
         return EXT_CONTINUE
 
     def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
index a3714bc7e6c4b7638c288352cedbe127be4e8402..1f704f5023767fc301e65e68619fd232de8a148c 100644 (file)
@@ -981,16 +981,18 @@ class Session(object):
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
-            cascaded = list(_cascade_state_iterator('refresh-expire', state))
+            cascaded = list(state.manager.mapper.cascade_iterator(
+                                            'refresh-expire', state))
             self._conditional_expire(state)
-            for (state, m, o) in cascaded:
-                self._conditional_expire(state)
+            for o, m, st_, dct_ in cascaded:
+                self._conditional_expire(st_)
         
     def _conditional_expire(self, state):
         """Expire a state if persistent, else expunge if pending"""
         
         if state.key:
-            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+            _expire_state(state, state.dict, None, 
+                                instance_dict=self.identity_map)
         elif state in self._new:
             self._new.pop(state)
             state.detach()
@@ -1023,8 +1025,12 @@ class Session(object):
             raise sa_exc.InvalidRequestError(
                 "Instance %s is not present in this Session" %
                 mapperutil.state_str(state))
-        for s, m, o in [(state, None, None)] + list(_cascade_state_iterator('expunge', state)):
-            self._expunge_state(s)
+
+        cascaded = list(state.manager.mapper.cascade_iterator(
+                                    'expunge', state))
+        self._expunge_state(state)
+        for o, m, st_, dct_ in cascaded:
+            self._expunge_state(st_)
 
     def _expunge_state(self, state):
         if state in self._new:
@@ -1078,12 +1084,6 @@ class Session(object):
         self._deleted.pop(state, None)
         state.deleted = True
 
-    def _save_without_cascade(self, instance):
-        """Used by scoping.py to save on init without cascade."""
-
-        state = _state_for_unsaved_instance(instance, create=True)
-        self._save_impl(state)
-
     def add(self, instance):
         """Place an object in the ``Session``.
 
@@ -1094,7 +1094,11 @@ class Session(object):
         is ``expunge()``.
 
         """
-        state = _state_for_unknown_persistence_instance(instance)
+        try:
+            state = attributes.instance_state(instance)
+        except exc.NO_STATE:
+            raise exc.UnmappedInstanceError(instance)
+
         self._save_or_update_state(state)
 
     def add_all(self, instances):
@@ -1107,7 +1111,10 @@ class Session(object):
         self._save_or_update_impl(state)
 
         mapper = _state_mapper(state)
-        for o, m, st_, dct_ in mapper.cascade_iterator('save-update', state, halt_on=self._contains_state):
+        for o, m, st_, dct_ in mapper.cascade_iterator(
+                                    'save-update', 
+                                    state, 
+                                    halt_on=self._contains_state):
             self._save_or_update_impl(st_)
 
     def delete(self, instance):
@@ -1137,16 +1144,18 @@ class Session(object):
         # grab the cascades before adding the item to the deleted list
         # so that autoflush does not delete the item
         # the strong reference to the instance itself is significant here
-        cascade_states = list(_cascade_state_iterator('delete', state))
+        cascade_states = list(state.manager.mapper.cascade_iterator(
+                                            'delete', state))
 
         self._deleted[state] = state.obj()
         self.identity_map.add(state)
 
-        for state, m, o in cascade_states:
-            self._delete_impl(state)
+        for o, m, st_, dct_ in cascade_states:
+            self._delete_impl(st_)
 
     def merge(self, instance, load=True, **kw):
-        """Copy the state an instance onto the persistent instance with the same identifier.
+        """Copy the state an instance onto the persistent instance with the
+        same identifier.
 
         If there is no persistent instance currently associated with the
         session, it will be loaded.  Return the persistent instance. If the
@@ -1162,7 +1171,8 @@ class Session(object):
         """
         if 'dont_load' in kw:
             load = not kw['dont_load']
-            util.warn_deprecated("dont_load=True has been renamed to load=False.")
+            util.warn_deprecated('dont_load=True has been renamed to '
+                                 'load=False.')
         
         _recursive = {}
         
@@ -1239,7 +1249,9 @@ class Session(object):
             merged_state.load_options = state.load_options
             
             for prop in mapper.iterate_properties:
-                prop.merge(self, state, state_dict, merged_state, merged_dict, load, _recursive)
+                prop.merge(self, state, state_dict, 
+                                merged_state, merged_dict, 
+                                load, _recursive)
 
         if not load:
             # remove any history
@@ -1317,10 +1329,10 @@ class Session(object):
         if state.key and \
             state.key in self.identity_map and \
             not self.identity_map.contains_state(state):
-            raise sa_exc.InvalidRequestError(
-                "Can't attach instance %s; another instance with key %s is already present in this session." % 
-                    (mapperutil.state_str(state), state.key)
-                )
+            raise sa_exc.InvalidRequestError("Can't attach instance "
+                    "%s; another instance with key %s is already "
+                    "present in this session."
+                    % (mapperutil.state_str(state), state.key))
                 
         if state.session_id and state.session_id is not self.hash_key:
             raise sa_exc.InvalidRequestError(
@@ -1473,7 +1485,8 @@ class Session(object):
         #if not objects:
         #    assert not self.identity_map._modified
         #else:
-        #    assert self.identity_map._modified == self.identity_map._modified.difference(objects)
+        #    assert self.identity_map._modified == \
+        #            self.identity_map._modified.difference(objects)
         #self.identity_map._modified.clear()
         
         self.dispatch.on_after_flush_postexec(self, flush_context)
@@ -1596,42 +1609,6 @@ UOWEventHandler = unitofwork.UOWEventHandler
 
 _sessions = weakref.WeakValueDictionary()
 
-def _cascade_state_iterator(cascade, state, **kwargs):
-    mapper = _state_mapper(state)
-    # yield the state, object, mapper.  yielding the object
-    # allows the iterator's results to be held in a list without
-    # states being garbage collected
-    for o, m, st_, dct_ in mapper.cascade_iterator(cascade, state, **kwargs):
-        yield st_, o, m
-
-def _state_for_unsaved_instance(instance, create=False):
-    try:
-        state = attributes.instance_state(instance)
-    except AttributeError:
-        raise exc.UnmappedInstanceError(instance)
-    if state:
-        if state.key is not None:
-            raise sa_exc.InvalidRequestError(
-                "Instance '%s' is already persistent" %
-                mapperutil.state_str(state))
-    elif create:
-        manager = attributes.manager_of_class(instance.__class__)
-        if manager is None:
-            raise exc.UnmappedInstanceError(instance)
-        state = manager.setup_instance(instance)
-    else:
-        raise exc.UnmappedInstanceError(instance)
-
-    return state
-
-def _state_for_unknown_persistence_instance(instance):
-    try:
-        state = attributes.instance_state(instance)
-    except exc.NO_STATE:
-        raise exc.UnmappedInstanceError(instance)
-
-    return state
-
 def make_transient(instance):
     """Make the given instance 'transient'.