]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- most of the __init__ decoration has been removed from mapper, save for
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 27 May 2007 03:20:11 +0000 (03:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 27 May 2007 03:20:11 +0000 (03:20 +0000)
that the mappers all get compiled when an instance of a mapped class is first constructed.
the SessionContextExt extension gets all the "add object to the session" logic now and the
_sa_session and _sa_entity_name arguments only apply to when the SessionContextExt is in use.
Some extra methods to MapperExtension to support __init__ decoration.
- assignmapper loses "join_to", gains "join".  id like to replace all those methods with just "query"
but i think they are too popular, so it should probably get filter(), filter_by() also.

lib/sqlalchemy/ext/assignmapper.py
lib/sqlalchemy/ext/sessioncontext.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
test/ext/activemapper.py
test/ext/selectresults.py
test/orm/inheritance/basic.py
test/orm/inheritance/manytomany.py
test/orm/mapper.py

index b7d5411b652acec3e8fb0d3d1fa80ed063f4cac2..7886f4d272316817bdc64e3da3bcad411406f99a 100644 (file)
@@ -14,15 +14,13 @@ def monkeypatch_query_method(ctx, class_, name):
 def monkeypatch_objectstore_method(ctx, class_, name):
     def do(self, *args, **kwargs):
         session = ctx.current
-        if name == "flush":
-            # flush expects a list of objects
-            self = [self]
         return getattr(session, name)(self, *args, **kwargs)
     try:
         do.__name__ = name
     except:
         pass
     setattr(class_, name, do)
+
     
 def assign_mapper(ctx, class_, *args, **kwargs):
     validate = kwargs.pop('validate', False)
@@ -43,9 +41,9 @@ def assign_mapper(ctx, class_, *args, **kwargs):
     m = mapper(class_, extension=extension, *args, **kwargs)
     class_.mapper = m
     class_.query = classmethod(lambda cls: Query(class_, session=ctx.current))
-    for name in ['get', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']:
+    for name in ['get', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'get_by', 'join', 'count', 'count_by', 'options', 'instances']:
         monkeypatch_query_method(ctx, class_, name)
-    for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']:
+    for name in ['delete', 'expire', 'refresh', 'expunge', 'save', 'update', 'save_or_update']:
         monkeypatch_objectstore_method(ctx, class_, name)
     return m
 
index 2f81e55d2c4801fe7efa42cce51dd032c234b163..5a0e0afc77911c12c2f98dfa0eff235aca41f071 100644 (file)
@@ -1,5 +1,6 @@
 from sqlalchemy.util import ScopedRegistry
-from sqlalchemy.orm.mapper import MapperExtension
+from sqlalchemy.orm.mapper import MapperExtension, EXT_PASS
+from sqlalchemy.orm import create_session
 
 __all__ = ['SessionContext', 'SessionContextExt']
 
@@ -24,7 +25,9 @@ class SessionContext(object):
                           # be created on the next call to context.current)
     """
 
-    def __init__(self, session_factory, scopefunc=None):
+    def __init__(self, session_factory=None, scopefunc=None):
+        if session_factory is None:
+            session_factory = create_session
         self.registry = ScopedRegistry(session_factory, scopefunc)
         super(SessionContext, self).__init__()
 
@@ -60,3 +63,21 @@ class SessionContextExt(MapperExtension):
 
     def get_session(self):
         return self.context.current
+
+    def init_instance(self, mapper, class_, instance, args, kwargs):
+        session = kwargs.pop('_sa_session', self.context.current)
+        session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+        return EXT_PASS
+
+    def init_failed(self, mapper, class_, instance, args, kwargs):
+        object_session(instance).expunge(instance)
+        return EXT_PASS
+        
+    def dispose_class(self, mapper, class_):
+        if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+            if class_.__init__._oldinit is not None:
+                class_.__init__ = class_.__init__._oldinit
+            else:
+                delattr(class_, '__init__')
+                
+            
\ No newline at end of file
index a1fa1726e8df12726c221a2b0104c7709543f334..ce22f462391a49a2102927df0dd43190a117988e 100644 (file)
@@ -118,9 +118,7 @@ def clear_mappers():
     """
 
     for mapper in mapper_registry.values():
-        attribute_manager.reset_class_managed(mapper.class_)
-        if hasattr(mapper.class_, 'c'):
-            del mapper.class_.c
+        mapper.dispose()
     mapper_registry.clear()
     sautil.ArgSingleton.instances.clear()
 
index 7453d1caa5197a30369e7d2056f5ff6130812727..223c270b0f322039ffe84415cf63f4c2f2758d4c 100644 (file)
@@ -319,6 +319,16 @@ class Mapper(object):
     props = property(_get_props, doc="compiles this mapper if needed, and returns the "
                      "dictionary of MapperProperty objects associated with this mapper.")
 
+    def dispose(self):
+        attribute_manager.reset_class_managed(self.class_)
+        if hasattr(self.class_, 'c'):
+            del self.class_.c
+        if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'):
+            if self.class_.__init__._oldinit is not None:
+                self.class_.__init__ = self.class_.__init__._oldinit
+            else:
+                delattr(self.class_, '__init__')
+        
     def compile(self):
         """Compile this mapper into its final internal format.
 
@@ -574,7 +584,7 @@ class Mapper(object):
         
         this method is called repeatedly during the compilation process as 
         the resulting dictionary contains more equivalents as more inheriting 
-        mappers are compiled.  the repetition of this process may be open to some optimization.
+        mappers are compiled.  the repetition process may be open to some optimization.
         """
 
         result = {}
@@ -723,54 +733,31 @@ class Mapper(object):
         attribute_manager.reset_class_managed(self.class_)
 
         oldinit = self.class_.__init__
-        def init(self, *args, **kwargs):
-            entity_name = kwargs.pop('_sa_entity_name', None)
-            mapper = mapper_registry.get(ClassKey(self.__class__, entity_name))
-            if mapper is not None:
-                mapper = mapper.compile()
-
-                # this gets the AttributeManager to do some pre-initialization,
-                # in order to save on KeyErrors later on
-                attribute_manager.init_attr(self)
-
-            if kwargs.has_key('_sa_session'):
-                session = kwargs.pop('_sa_session')
-            else:
-                # works for whatever mapper the class is associated with
-                if mapper is not None:
-                    session = mapper.extension.get_session()
-                    if session is EXT_PASS:
-                        session = None
-                else:
-                    session = None
-            # if a session was found, either via _sa_session or via mapper extension,
-            # and we have found a mapper, save() this instance to the session, and give it an associated entity_name.
-            # otherwise, this instance will not have a session or mapper association until it is
-            # save()d to some session.
-            if session is not None and mapper is not None:
-                self._entity_name = entity_name
-                session._register_pending(self)
+        def init(instance, *args, **kwargs):
+            self.compile()
+            self.extension.init_instance(self, self.class_, instance, args, kwargs)
 
             if oldinit is not None:
                 try:
-                    oldinit(self, *args, **kwargs)
+                    oldinit(instance, *args, **kwargs)
                 except Exception, e:
                     try:
-                        if session is not None:
-                            session.expunge(self)
+                        self.extension.init_failed(self, self.class_, instance, args, kwargs)
                     except:
                         pass # raise original exception instead
                     raise e
-        # override oldinit, insuring that its not already a Mapper-decorated init method
-        if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'):
-            init._sa_mapper_init = True
+
+        # override oldinit, ensuring that its not already a Mapper-decorated init method
+        if oldinit is None or not hasattr(oldinit, '_oldinit'):
             try:
                 init.__name__ = oldinit.__name__
                 init.__doc__ = oldinit.__doc__
             except:
                 # cant set __name__ in py 2.3 !
                 pass
+            init._oldinit = oldinit
             self.class_.__init__ = init
+        
         mapper_registry[self.class_key] = self
         if self.entity_name is None:
             self.class_.c = self.c
@@ -1629,6 +1616,12 @@ class MapperExtension(object):
     is not overridden.
     """
 
+    def init_instance(self, mapper, class_, instance, args, kwargs):
+        return EXT_PASS
+
+    def init_failed(self, mapper, class_, instance, args, kwargs):
+        return EXT_PASS
+
     def get_session(self):
         """Retrieve a contextual Session instance with which to
         register a new object.
@@ -1830,7 +1823,10 @@ class _ExtensionCarrier(MapperExtension):
             else:
                 return EXT_PASS
         return _do
-        
+    
+    init_instance = _create_do('init_instance')
+    init_failed = _create_do('init_failed')
+    dispose_class = _create_do('dispose_class')
     get_session = _create_do('get_session')
     load = _create_do('load')
     get = _create_do('get')
index 61a37d435f374d9c2ff8fac51fb4e8073ec86cc4..81dbe07b8a400875ebf49c1dc22bc83b98ef6bbd 100644 (file)
@@ -179,7 +179,7 @@ class Query(object):
         ``select_by()`` method.
         """
 
-        return self._join_by(args, params)
+        return self._join_by(args, params, start=self._joinpoint)
 
 
     def join_to(self, key):
@@ -201,7 +201,7 @@ class Query(object):
         mapper.
         """
 
-        mapper = self.mapper
+        mapper = self._joinpoint
         clause = None
         for key in keys:
             prop = mapper.props[key]
index ebb832fdcd14a8f2b6002e5f233c4b99c54aa9af..25e20168c9d5aca7753a636c2feac7dac62b4808 100644 (file)
@@ -253,9 +253,8 @@ class testcase(testbase.PersistTest):
         objectstore.flush()
         objectstore.clear()
         
-        results = Person.select(
-            Address.c.postal_code.like('30075') &
-            Person.join_to('addresses')
+        results = Person.join('addresses').select(
+            Address.c.postal_code.like('30075') 
         )
         self.assertEquals(len(results), 1)
 
index 8df416be94c9ce74475ab64a2065bd244665bc9e..5af61b6a70647f7e6c28155475099b60d4e5cfd7 100644 (file)
@@ -82,7 +82,7 @@ class SelectResultsTest(PersistTest):
     
     def test_options(self):
         class ext1(MapperExtension):
-            def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
+            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
                 instance.TEST = "hello world"
                 return EXT_PASS
         objectstore.clear()
index 88344c17b1b577f788ebf3fd9e5da5290a0ebc50..f95390008e383d25bb372f88dc3eae58fe413dc7 100644 (file)
@@ -43,9 +43,12 @@ class O2MTest(testbase.ORMTest):
         })
 
         sess = create_session()
-        b1 = Blub("blub #1", _sa_session=sess)
-        b2 = Blub("blub #2", _sa_session=sess)
-        f = Foo("foo #1", _sa_session=sess)
+        b1 = Blub("blub #1")
+        b2 = Blub("blub #2")
+        f = Foo("foo #1")
+        sess.save(b1)
+        sess.save(b2)
+        sess.save(f)
         b1.parent_foo = f
         b2.parent_foo = f
         sess.flush()
index f97b8ed0d558d7383fe9259a47e39441d85f234f..b5cc83e7b4c26e80f7d45aeeaa2fd2cc53025017 100644 (file)
@@ -193,7 +193,8 @@ class InheritTest3(testbase.ORMTest):
         })
 
         sess = create_session()
-        b = Bar('bar #1', _sa_session=sess)
+        b = Bar('bar #1')
+        sess.save(b)
         b.foos.append(Foo("foo #1"))
         b.foos.append(Foo("foo #2"))
         sess.flush()
@@ -226,10 +227,12 @@ class InheritTest3(testbase.ORMTest):
         })
 
         sess = create_session()
-        f1 = Foo("foo #1", _sa_session=sess)
-        b1 = Bar("bar #1", _sa_session=sess)
-        b2 = Bar("bar #2", _sa_session=sess)
-        bl1 = Blub("blub #1", _sa_session=sess)
+        f1 = Foo("foo #1")
+        b1 = Bar("bar #1")
+        b2 = Bar("bar #2")
+        bl1 = Blub("blub #1")
+        for o in (f1, b1, b2, bl1):
+            sess.save(o)
         bl1.foos.append(f1)
         bl1.bars.append(b2)
         sess.flush()
index de4dfa0ce03ba004381e866a2717797809962d1d..8dd9c0cca38622eace756542a65274af046e70bd 100644 (file)
@@ -3,7 +3,7 @@ import testbase
 import unittest, sys, os
 from sqlalchemy import *
 import sqlalchemy.exceptions as exceptions
-from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
 from tables import *
 import tables
 
@@ -167,7 +167,7 @@ class MapperTest(MapperSuperTest):
             def __init__(self):
                 raise ex
         mapper(Foo, users)
-        
+
         try:
             Foo()
             assert False
@@ -179,7 +179,7 @@ class MapperTest(MapperSuperTest):
                 object_session(self).expunge(self)
                 raise ex
 
-        mapper(Bar, orders)
+        mapper(Bar, orders, extension=SessionContextExt(SessionContext()))
 
         try:
             Bar(_sa_session=sess)