From 481cb0873fe7910387ec05333d6524fba4d352fe Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 27 May 2007 03:20:11 +0000 Subject: [PATCH] - most of the __init__ decoration has been removed from mapper, save for that the mappers all get compiled when an instance of a mapped class is first constructed. the SessionContextExt extension gets all the "add object to the session" logic now and the _sa_session and _sa_entity_name arguments only apply to when the SessionContextExt is in use. Some extra methods to MapperExtension to support __init__ decoration. - assignmapper loses "join_to", gains "join". id like to replace all those methods with just "query" but i think they are too popular, so it should probably get filter(), filter_by() also. --- lib/sqlalchemy/ext/assignmapper.py | 8 ++-- lib/sqlalchemy/ext/sessioncontext.py | 25 ++++++++++- lib/sqlalchemy/orm/__init__.py | 4 +- lib/sqlalchemy/orm/mapper.py | 66 +++++++++++++--------------- lib/sqlalchemy/orm/query.py | 4 +- test/ext/activemapper.py | 5 +-- test/ext/selectresults.py | 2 +- test/orm/inheritance/basic.py | 9 ++-- test/orm/inheritance/manytomany.py | 13 +++--- test/orm/mapper.py | 6 +-- 10 files changed, 80 insertions(+), 62 deletions(-) diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py index b7d5411b65..7886f4d272 100644 --- a/lib/sqlalchemy/ext/assignmapper.py +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -14,15 +14,13 @@ def monkeypatch_query_method(ctx, class_, name): def monkeypatch_objectstore_method(ctx, class_, name): def do(self, *args, **kwargs): session = ctx.current - if name == "flush": - # flush expects a list of objects - self = [self] return getattr(session, name)(self, *args, **kwargs) try: do.__name__ = name except: pass setattr(class_, name, do) + def assign_mapper(ctx, class_, *args, **kwargs): validate = kwargs.pop('validate', False) @@ -43,9 +41,9 @@ def assign_mapper(ctx, class_, *args, **kwargs): m = mapper(class_, extension=extension, *args, **kwargs) class_.mapper = m class_.query = classmethod(lambda cls: Query(class_, session=ctx.current)) - for name in ['get', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']: + for name in ['get', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'get_by', 'join', 'count', 'count_by', 'options', 'instances']: monkeypatch_query_method(ctx, class_, name) - for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']: + for name in ['delete', 'expire', 'refresh', 'expunge', 'save', 'update', 'save_or_update']: monkeypatch_objectstore_method(ctx, class_, name) return m diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py index 2f81e55d2c..5a0e0afc77 100644 --- a/lib/sqlalchemy/ext/sessioncontext.py +++ b/lib/sqlalchemy/ext/sessioncontext.py @@ -1,5 +1,6 @@ from sqlalchemy.util import ScopedRegistry -from sqlalchemy.orm.mapper import MapperExtension +from sqlalchemy.orm.mapper import MapperExtension, EXT_PASS +from sqlalchemy.orm import create_session __all__ = ['SessionContext', 'SessionContextExt'] @@ -24,7 +25,9 @@ class SessionContext(object): # be created on the next call to context.current) """ - def __init__(self, session_factory, scopefunc=None): + def __init__(self, session_factory=None, scopefunc=None): + if session_factory is None: + session_factory = create_session self.registry = ScopedRegistry(session_factory, scopefunc) super(SessionContext, self).__init__() @@ -60,3 +63,21 @@ class SessionContextExt(MapperExtension): def get_session(self): return self.context.current + + def init_instance(self, mapper, class_, instance, args, kwargs): + session = kwargs.pop('_sa_session', self.context.current) + session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None)) + return EXT_PASS + + def init_failed(self, mapper, class_, instance, args, kwargs): + object_session(instance).expunge(instance) + return EXT_PASS + + def dispose_class(self, mapper, class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + + \ No newline at end of file diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index a1fa1726e8..ce22f46239 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -118,9 +118,7 @@ def clear_mappers(): """ for mapper in mapper_registry.values(): - attribute_manager.reset_class_managed(mapper.class_) - if hasattr(mapper.class_, 'c'): - del mapper.class_.c + mapper.dispose() mapper_registry.clear() sautil.ArgSingleton.instances.clear() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 7453d1caa5..223c270b0f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -319,6 +319,16 @@ class Mapper(object): props = property(_get_props, doc="compiles this mapper if needed, and returns the " "dictionary of MapperProperty objects associated with this mapper.") + def dispose(self): + attribute_manager.reset_class_managed(self.class_) + if hasattr(self.class_, 'c'): + del self.class_.c + if hasattr(self.class_, '__init__') and hasattr(self.class_.__init__, '_oldinit'): + if self.class_.__init__._oldinit is not None: + self.class_.__init__ = self.class_.__init__._oldinit + else: + delattr(self.class_, '__init__') + def compile(self): """Compile this mapper into its final internal format. @@ -574,7 +584,7 @@ class Mapper(object): this method is called repeatedly during the compilation process as the resulting dictionary contains more equivalents as more inheriting - mappers are compiled. the repetition of this process may be open to some optimization. + mappers are compiled. the repetition process may be open to some optimization. """ result = {} @@ -723,54 +733,31 @@ class Mapper(object): attribute_manager.reset_class_managed(self.class_) oldinit = self.class_.__init__ - def init(self, *args, **kwargs): - entity_name = kwargs.pop('_sa_entity_name', None) - mapper = mapper_registry.get(ClassKey(self.__class__, entity_name)) - if mapper is not None: - mapper = mapper.compile() - - # this gets the AttributeManager to do some pre-initialization, - # in order to save on KeyErrors later on - attribute_manager.init_attr(self) - - if kwargs.has_key('_sa_session'): - session = kwargs.pop('_sa_session') - else: - # works for whatever mapper the class is associated with - if mapper is not None: - session = mapper.extension.get_session() - if session is EXT_PASS: - session = None - else: - session = None - # if a session was found, either via _sa_session or via mapper extension, - # and we have found a mapper, save() this instance to the session, and give it an associated entity_name. - # otherwise, this instance will not have a session or mapper association until it is - # save()d to some session. - if session is not None and mapper is not None: - self._entity_name = entity_name - session._register_pending(self) + def init(instance, *args, **kwargs): + self.compile() + self.extension.init_instance(self, self.class_, instance, args, kwargs) if oldinit is not None: try: - oldinit(self, *args, **kwargs) + oldinit(instance, *args, **kwargs) except Exception, e: try: - if session is not None: - session.expunge(self) + self.extension.init_failed(self, self.class_, instance, args, kwargs) except: pass # raise original exception instead raise e - # override oldinit, insuring that its not already a Mapper-decorated init method - if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): - init._sa_mapper_init = True + + # override oldinit, ensuring that its not already a Mapper-decorated init method + if oldinit is None or not hasattr(oldinit, '_oldinit'): try: init.__name__ = oldinit.__name__ init.__doc__ = oldinit.__doc__ except: # cant set __name__ in py 2.3 ! pass + init._oldinit = oldinit self.class_.__init__ = init + mapper_registry[self.class_key] = self if self.entity_name is None: self.class_.c = self.c @@ -1629,6 +1616,12 @@ class MapperExtension(object): is not overridden. """ + def init_instance(self, mapper, class_, instance, args, kwargs): + return EXT_PASS + + def init_failed(self, mapper, class_, instance, args, kwargs): + return EXT_PASS + def get_session(self): """Retrieve a contextual Session instance with which to register a new object. @@ -1830,7 +1823,10 @@ class _ExtensionCarrier(MapperExtension): else: return EXT_PASS return _do - + + init_instance = _create_do('init_instance') + init_failed = _create_do('init_failed') + dispose_class = _create_do('dispose_class') get_session = _create_do('get_session') load = _create_do('load') get = _create_do('get') diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 61a37d435f..81dbe07b8a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -179,7 +179,7 @@ class Query(object): ``select_by()`` method. """ - return self._join_by(args, params) + return self._join_by(args, params, start=self._joinpoint) def join_to(self, key): @@ -201,7 +201,7 @@ class Query(object): mapper. """ - mapper = self.mapper + mapper = self._joinpoint clause = None for key in keys: prop = mapper.props[key] diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py index ebb832fdcd..25e20168c9 100644 --- a/test/ext/activemapper.py +++ b/test/ext/activemapper.py @@ -253,9 +253,8 @@ class testcase(testbase.PersistTest): objectstore.flush() objectstore.clear() - results = Person.select( - Address.c.postal_code.like('30075') & - Person.join_to('addresses') + results = Person.join('addresses').select( + Address.c.postal_code.like('30075') ) self.assertEquals(len(results), 1) diff --git a/test/ext/selectresults.py b/test/ext/selectresults.py index 8df416be94..5af61b6a70 100644 --- a/test/ext/selectresults.py +++ b/test/ext/selectresults.py @@ -82,7 +82,7 @@ class SelectResultsTest(PersistTest): def test_options(self): class ext1(MapperExtension): - def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew): + def populate_instance(self, mapper, selectcontext, row, instance, **flags): instance.TEST = "hello world" return EXT_PASS objectstore.clear() diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index 88344c17b1..f95390008e 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -43,9 +43,12 @@ class O2MTest(testbase.ORMTest): }) sess = create_session() - b1 = Blub("blub #1", _sa_session=sess) - b2 = Blub("blub #2", _sa_session=sess) - f = Foo("foo #1", _sa_session=sess) + b1 = Blub("blub #1") + b2 = Blub("blub #2") + f = Foo("foo #1") + sess.save(b1) + sess.save(b2) + sess.save(f) b1.parent_foo = f b2.parent_foo = f sess.flush() diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py index f97b8ed0d5..b5cc83e7b4 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/manytomany.py @@ -193,7 +193,8 @@ class InheritTest3(testbase.ORMTest): }) sess = create_session() - b = Bar('bar #1', _sa_session=sess) + b = Bar('bar #1') + sess.save(b) b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) sess.flush() @@ -226,10 +227,12 @@ class InheritTest3(testbase.ORMTest): }) sess = create_session() - f1 = Foo("foo #1", _sa_session=sess) - b1 = Bar("bar #1", _sa_session=sess) - b2 = Bar("bar #2", _sa_session=sess) - bl1 = Blub("blub #1", _sa_session=sess) + f1 = Foo("foo #1") + b1 = Bar("bar #1") + b2 = Bar("bar #2") + bl1 = Blub("blub #1") + for o in (f1, b1, b2, bl1): + sess.save(o) bl1.foos.append(f1) bl1.bars.append(b2) sess.flush() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index de4dfa0ce0..8dd9c0cca3 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -3,7 +3,7 @@ import testbase import unittest, sys, os from sqlalchemy import * import sqlalchemy.exceptions as exceptions -from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt from tables import * import tables @@ -167,7 +167,7 @@ class MapperTest(MapperSuperTest): def __init__(self): raise ex mapper(Foo, users) - + try: Foo() assert False @@ -179,7 +179,7 @@ class MapperTest(MapperSuperTest): object_session(self).expunge(self) raise ex - mapper(Bar, orders) + mapper(Bar, orders, extension=SessionContextExt(SessionContext())) try: Bar(_sa_session=sess) -- 2.47.3