]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added query_cls= override to scoped_session's query_property
authorJason Kirtland <jek@discorporate.us>
Sat, 27 Sep 2008 01:37:26 +0000 (01:37 +0000)
committerJason Kirtland <jek@discorporate.us>
Sat, 27 Sep 2008 01:37:26 +0000 (01:37 +0000)
CHANGES
lib/sqlalchemy/orm/scoping.py
test/engine/reconnect.py
test/orm/memusage.py
test/orm/scoping.py

diff --git a/CHANGES b/CHANGES
index 5fb8861cb042c8204492ff7866a076cc71756225..abd018871c088d5ee60db876f15898f0478b6968 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -21,7 +21,10 @@ CHANGES
       upon reentrant mapper compile() calls, something that
       occurs when using declarative constructs inside of
       ForeignKey objects.
-      
+
+    - ScopedSession.query_property now accepts a query_cls factory,
+      overriding the session's configured query_cls.
+
 - sql
     - column.in_(someselect) can now be used as 
       a columns-clause expression without the subquery
index 422b362633c06df4543213334e6846cf3520ab55..5dd17a2898fe2ea06577b395d4bc5a632526d056 100644 (file)
@@ -77,7 +77,7 @@ class ScopedSession(object):
 
         self.session_factory.configure(**kwargs)
 
-    def query_property(self):
+    def query_property(self, query_cls=None):
         """return a class property which produces a `Query` object against the
         class when called.
 
@@ -90,13 +90,26 @@ class ScopedSession(object):
             # after mappers are defined
             result = MyClass.query.filter(MyClass.name=='foo').all()
 
-        """
+        Produces instances of the session's configured query class by
+        default.  To override and use a custom implementation, provide
+        a ``query_cls`` callable.  The callable will be invoked with
+        the class's mapper as a positional argument and a session
+        keyword argument.
+
+        There is no limit to the number of query properties placed on
+        a class.
 
+        """
         class query(object):
             def __get__(s, instance, owner):
                 mapper = class_mapper(owner, raiseerror=False)
                 if mapper:
-                    return self.registry().query(mapper)
+                    if query_cls:
+                        # custom query class
+                        return query_cls(mapper, session=self.registry())
+                    else:
+                        # session's configured query class
+                        return self.registry().query(mapper)
                 else:
                     return None
         return query()
index d50267a1f24ea2bc134a91d56eaa923ae90a8327..f08dbcdd540c9b289207038f2c6274db165821b3 100644 (file)
@@ -46,7 +46,7 @@ class MockCursor(object):
         pass
 
 db, dbapi = None, None
-class MockReconnectTest(TestBase):
+class MockReconnectTest(object):
     def setUp(self):
         global db, dbapi
         dbapi = MockDBAPI()
@@ -175,7 +175,7 @@ class MockReconnectTest(TestBase):
         assert len(dbapi.connections) == 1
 
 engine = None
-class RealReconnectTest(TestBase):
+class RealReconnectTest(object):
     def setUp(self):
         global engine
         engine = engines.reconnecting_engine()
@@ -281,7 +281,7 @@ class RealReconnectTest(TestBase):
         self.assertEquals(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
 
-class RecycleTest(TestBase):
+class RecycleTest(object):
     def test_basic(self):
         for threadlocal in (False, True):
             engine = engines.reconnecting_engine(options={'pool_recycle':1, 'pool_threadlocal':threadlocal})
index 6348f22371be9de97029916d47ffe58059a47391..06e2c1423885735bb5faabe58a40372a67bd3a68 100644 (file)
@@ -54,7 +54,7 @@ class EnsureZeroed(_base.ORMTest):
         _sessions.clear()
         _mapper_registry.clear()
 
-class MemUsageTest(EnsureZeroed):
+class MemUsageTest(object):
     
     # ensure a pure growing test trips the assertion
     @testing.fails_if(lambda:True)
index cdc0c16b413c83a7efdb44af0ffe0c2c5ec273fb..32e0dedb0e56e9201544e98af6e037214cc088fc 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 from testlib import sa, testing
 from sqlalchemy.orm import scoped_session
 from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation
+from testlib.sa.orm import mapper, relation, query
 from testlib.testing import eq_
 from orm import _base
 
@@ -38,10 +38,14 @@ class ScopedSessionTest(_base.MappedTest):
     def test_basic(self):
         Session = scoped_session(sa.orm.sessionmaker())
 
+        class CustomQuery(query.Query):
+            pass
+
         class SomeObject(_base.ComparableEntity):
             query = Session.query_property()
         class SomeOtherObject(_base.ComparableEntity):
             query = Session.query_property()
+            custom_query = Session.query_property(query_cls=CustomQuery)
 
         mapper(SomeObject, table1, properties={
             'options':relation(SomeOtherObject)})
@@ -62,6 +66,9 @@ class ScopedSessionTest(_base.MappedTest):
         eq_(SomeOtherObject(someid=1),
             SomeOtherObject.query.filter(
                 SomeOtherObject.someid == sso.someid).one())
+        assert isinstance(SomeOtherObject.query, query.Query)
+        assert not isinstance(SomeOtherObject.query, CustomQuery)
+        assert isinstance(SomeOtherObject.custom_query, query.Query)
 
 
 class ScopedMapperTest(_ScopedTest):