]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- removed enhance_classes from scoped_session, replaced with
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Aug 2007 19:31:38 +0000 (19:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Aug 2007 19:31:38 +0000 (19:31 +0000)
scoped_session(...).mapper.  'mapper' essentially does the same
thing as assign_mapper less verbosely.
- adapted assignmapper unit tests into scoped_session tests

lib/sqlalchemy/orm/scoping.py
test/orm/session.py
test/orm/unitofwork.py

index 96d9a23fc0f19d8cd2d6fb6151bbd9b07feaf786..5d11a99a4ae57dd2cbdc1d74a45dc6a53587849e 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy.util import ScopedRegistry, warn_deprecated
+from sqlalchemy.util import ScopedRegistry, warn_deprecated, to_list
 from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
 from sqlalchemy.orm.session import Session
 from sqlalchemy.orm.mapper import global_extensions
@@ -13,16 +13,21 @@ class ScopedSession(object):
 
     Usage::
 
-      Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)
+      Session = scoped_session(sessionmaker(autoflush=True))
+      
+      To map classes so that new instances are saved in the current
+      Session automatically, as well as to provide session-aware
+      class attributes such as "query":
+      
+      mapper = Session.mapper
+      mapper(Class, table, ...)
 
     """
 
-    def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
+    def __init__(self, session_factory, scopefunc=None):
         self.session_factory = session_factory
-        self.enhance_classes = enhance_classes
         self.registry = ScopedRegistry(session_factory, scopefunc)
-        if self.enhance_classes:
-            global_extensions.append(_ScopedExt(self))
+        self.extension = _ScopedExt(self)
 
     def __call__(self, **kwargs):
         if kwargs:
@@ -39,15 +44,28 @@ class ScopedSession(object):
         else:
             return self.registry()
 
+    def mapper(self, *args, **kwargs):
+        """return a mapper() function which associates this ScopedSession with the Mapper."""
+        
+        from sqlalchemy.orm import mapper
+        validate = kwargs.pop('validate', False)
+        extension = to_list(kwargs.setdefault('extension', []))
+        if validate:
+            extension.append(self.extension.validating())
+        else:
+            extension.append(self.extension)
+        return mapper(*args, **kwargs)
+        
     def configure(self, **kwargs):
-        """reconfigure the sessionmaker used by this SessionContext"""
+        """reconfigure the sessionmaker used by this ScopedSession."""
+        
         self.session_factory.configure(**kwargs)
 
 def instrument(name):
     def do(self, *args, **kwargs):
         return getattr(self.registry(), name)(*args, **kwargs)
     return do
-for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
+for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete', 'clear'):
     setattr(ScopedSession, meth, instrument(meth))
 
 def makeprop(name):
@@ -67,18 +85,22 @@ for prop in ('close_all',):
     setattr(ScopedSession, prop, clslevel(prop))
     
 class _ScopedExt(MapperExtension):
-    def __init__(self, context):
+    def __init__(self, context, validate=False):
         self.context = context
+        self.validate = validate
+    
+    def validating(self):
+        return _ScopedExt(self.context, validate=True)
         
     def get_session(self):
         return self.context.registry()
 
     def instrument_class(self, mapper, class_):
         class query(object):
-            def __getattr__(self, key):
-                return getattr(registry().query(class_), key)
-            def __call__(self):
-                return registry().query(class_)
+            def __getattr__(s, key):
+                return getattr(self.context.registry().query(class_), key)
+            def __call__(s):
+                return self.context.registry().query(class_)
 
         if not hasattr(class_, 'query'): 
             class_.query = query()
@@ -87,9 +109,9 @@ class _ScopedExt(MapperExtension):
         session = kwargs.pop('_sa_session', self.context.registry())
         if not isinstance(oldinit, types.MethodType):
             for key, value in kwargs.items():
-                #if validate:
-                #    if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
-                #        raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+                if self.validate:
+                    if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+                        raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
                 setattr(instance, key, value)
         session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
         return EXT_CONTINUE
index d3eed5c5707181d6c48bcd9b68366ead04211b1a..0b56b84d4fe7c3d68ec3a940e1150cb2313a2564 100644 (file)
@@ -4,7 +4,6 @@ from sqlalchemy.orm import *
 from testlib import *
 from testlib.tables import *
 import testlib.tables as tables
-from sqlalchemy.orm.session import Session
 
 class SessionTest(AssertMixin):
     def setUpAll(self):
@@ -98,7 +97,7 @@ class SessionTest(AssertMixin):
         conn1 = testbase.db.connect()
         conn2 = testbase.db.connect()
         
-        sess = Session(bind=conn1, transactional=True, autoflush=True)
+        sess = create_session(bind=conn1, transactional=True, autoflush=True)
         u = User()
         u.user_name='ed'
         sess.save(u)
@@ -116,7 +115,7 @@ class SessionTest(AssertMixin):
         mapper(User, users)
 
         try:
-            sess = Session(transactional=True, autoflush=True)
+            sess = create_session(transactional=True, autoflush=True)
             u = User()
             u.user_name='ed'
             sess.save(u)
@@ -137,7 +136,7 @@ class SessionTest(AssertMixin):
         conn1 = testbase.db.connect()
         conn2 = testbase.db.connect()
         
-        sess = Session(bind=conn1, transactional=True, autoflush=True)
+        sess = create_session(bind=conn1, transactional=True, autoflush=True)
         u = User()
         u.user_name='ed'
         sess.save(u)
@@ -153,7 +152,7 @@ class SessionTest(AssertMixin):
             'addresses':relation(Address)
         })
         
-        sess = Session(transactional=True, autoflush=True)
+        sess = create_session(transactional=True, autoflush=True)
         u = sess.query(User).get(8)
         newad = Address()
         newad.email_address == 'something new'
@@ -173,7 +172,7 @@ class SessionTest(AssertMixin):
         mapper(User, users)
         conn = testbase.db.connect()
         trans = conn.begin()
-        sess = Session(conn, transactional=True, autoflush=True)
+        sess = create_session(bind=conn, transactional=True, autoflush=True)
         sess.begin() 
         u = User()
         sess.save(u)
@@ -189,7 +188,7 @@ class SessionTest(AssertMixin):
         try:
             conn = testbase.db.connect()
             trans = conn.begin()
-            sess = Session(conn, transactional=True, autoflush=True)
+            sess = create_session(bind=conn, transactional=True, autoflush=True)
             u1 = User()
             sess.save(u1)
             sess.flush()
@@ -217,7 +216,7 @@ class SessionTest(AssertMixin):
         mapper(Address, addresses)
         
         engine2 = create_engine(testbase.db.url)
-        sess = Session(transactional=False, autoflush=False, twophase=True)
+        sess = create_session(transactional=False, autoflush=False, twophase=True)
         sess.bind_mapper(User, testbase.db)
         sess.bind_mapper(Address, engine2)
         sess.begin()
@@ -234,7 +233,7 @@ class SessionTest(AssertMixin):
     def test_joined_transaction(self):
         class User(object):pass
         mapper(User, users)
-        sess = Session(transactional=True, autoflush=True)
+        sess = create_session(transactional=True, autoflush=True)
         sess.begin()  
         u = User()
         sess.save(u)
@@ -440,6 +439,75 @@ class SessionTest(AssertMixin):
         key = s.identity_key(User, row=row, entity_name="en")
         self._assert_key(key, (User, (1,), "en"))
         
+class ScopedSessionTest(PersistTest):
+    def setUpAll(self):
+        global metadata, table, table2
+        metadata = MetaData(testbase.db)
+        table = Table('sometable', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)))
+        table2 = Table('someothertable', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('someid', None, ForeignKey('sometable.id'))
+            )
+        metadata.create_all()
+
+    def setUp(self):
+        global SomeObject, SomeOtherObject
+        class SomeObject(object):pass
+        class SomeOtherObject(object):pass
         
+        global Session
+        
+        Session = scoped_session(create_session)
+        Session.mapper(SomeObject, table, properties={
+            'options':relation(SomeOtherObject)
+        })
+        Session.mapper(SomeOtherObject, table2)
+
+        s = SomeObject()
+        s.id = 1
+        s.data = 'hello'
+        sso = SomeOtherObject()
+        s.options.append(sso)
+        Session.flush()
+        Session.clear()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+        
+    def tearDown(self):
+        for table in metadata.table_iterator(reverse=True):
+            table.delete().execute()
+        clear_mappers()
+
+    def test_query(self):
+        sso = SomeOtherObject.query().first()
+        assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+    def test_validating_constructor(self):
+        s2 = SomeObject(someid=12)
+        s3 = SomeOtherObject(someid=123, bogus=345)
+
+        class ValidatedOtherObject(object):pass
+        Session.mapper(ValidatedOtherObject, table2, validate=True)
+
+        v1 = ValidatedOtherObject(someid=12)
+        try:
+            v2 = ValidatedOtherObject(someid=12, bogus=345)
+            assert False
+        except exceptions.ArgumentError:
+            pass
+
+    def test_dont_clobber_methods(self):
+        class MyClass(object):
+            def expunge(self):
+                return "an expunge !"
+
+        Session.mapper(MyClass, table2)
+
+        assert MyClass().expunge() == "an expunge !"
+    
+
 if __name__ == "__main__":    
     testbase.main()
index f28e428ed981c3df5081673b2fad7f6b3290fc75..f065267f7d33efed3e270e1bc18dcf91dac78c86 100644 (file)
@@ -12,8 +12,9 @@ from testlib import tables
 
 class UnitOfWorkTest(AssertMixin):
     def setUpAll(self):
-        global Session
-        Session = scoped_session(sessionmaker(autoflush=True, transactional=True), enhance_classes=True)
+        global Session, mapper
+        Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+        mapper = Session.mapper
     def tearDownAll(self):
         global_extensions[:] = []
     def tearDown(self):