]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved another chunk of mapper.py tests to query.py test suite
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jun 2007 23:12:03 +0000 (23:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jun 2007 23:12:03 +0000 (23:12 +0000)
- got all tests/extensions working with new APIs
- axed proxyengine until further notice
- SelectResults folds into a 10 line wrapper for Query, loses join_to() (use join())
- test cleanup

23 files changed:
CHANGES
doc/build/content/plugins.txt
lib/sqlalchemy/ext/activemapper.py
lib/sqlalchemy/ext/proxy.py [deleted file]
lib/sqlalchemy/ext/selectresults.py
lib/sqlalchemy/ext/sqlsoup.py
test/engine/alltests.py
test/engine/proxy_engine.py [deleted file]
test/engine/transaction.py
test/ext/activemapper.py
test/ext/assignmapper.py
test/ext/associationproxy.py
test/ext/orderinglist.py
test/ext/selectresults.py
test/orm/eagertest3.py
test/orm/inheritance/polymorph2.py
test/orm/mapper.py
test/orm/query.py
test/sql/defaults.py
test/sql/unicode.py
test/testbase.py
test/zblog/mappers.py
test/zblog/tests.py

diff --git a/CHANGES b/CHANGES
index 615d150c9daf396cbd79b3aa22c1c8512678a2ec..5f4e67e01354760002d98e7db09386ef4bb07d23 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       argument, which can be set to 'select' or 'deferred'
     - added undefer_group() MapperOption, sets a set of "deferred"
       columns joined by a "group" to load as "undeferred".
-
 - sql
   - long-identifier detection fixed to use > rather than >= for 
     max ident length [ticket:589]
+
 - mysql
   - added 'fields' to reserved words [ticket:590]
-    
+- extensions
+  - proxyengine is temporarily removed, pending an actually working
+    replacement.
+  - SelectResults has been replaced by Query.  SelectResults / 
+    SelectResultsExt still exist but just return a slightly modified
+    Query object for backwards-compatibility.  join_to() method 
+    from SelectResults isn't present anymore, need to use join(). 
+      
 0.3.8
 - engines
   - added detach() to Connection, allows underlying DBAPI connection
index b4a0bebae9f68e3517eb6bce46041e66ecc6850e..2071b13859d6116c2a76564927c5fd13d485d062 100644 (file)
@@ -600,13 +600,4 @@ Full SqlSoup documentation is on the [SQLAlchemy Wiki](http://www.sqlalchemy.org
 
 ### ProxyEngine
 
-**Author:** Jason Pellerin
-
-The `ProxyEngine` is used to "wrap" an `Engine`, and via subclassing `ProxyEngine` one can instrument the functionality of an arbitrary `Engine` instance through the decorator pattern.  It also provides a `connect()` method which will send all `Engine` requests to different underlying engines.  Its functionality in that regard is largely superceded now by `DynamicMetaData` which is a better solution.
-
-    {python}
-    from sqlalchemy.ext.proxy import ProxyEngine
-    proxy = ProxyEngine()
-    
-    proxy.connect('postgres://user:pw@host/db')
-
+This extension is temporarily removed from the 0.4 series of SQLAlchemy.  A newer method of SQL instrumentation will eventually be re-introduced.
index 004caf84946038e4e31fe0532a366373981c5f47..3ba98f3457d6e5a3899ee9591b139b53e7318460 100644 (file)
@@ -1,11 +1,11 @@
-from sqlalchemy             import create_session, relation, mapper, \
-                                   join, DynamicMetaData, class_mapper, \
-                                   util, Integer
+from sqlalchemy             import join, DynamicMetaData, util, Integer
 from sqlalchemy             import and_, or_
 from sqlalchemy             import Table, Column, ForeignKey
+from sqlalchemy.orm         import class_mapper, relation, mapper, create_session
+                                   
 from sqlalchemy.ext.sessioncontext import SessionContext
 from sqlalchemy.ext.assignmapper import assign_mapper
-from sqlalchemy import backref as create_backref
+from sqlalchemy.orm import backref as create_backref
 import sqlalchemy
 
 import inspect
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py
deleted file mode 100644 (file)
index b81702f..0000000
+++ /dev/null
@@ -1,113 +0,0 @@
-try:
-    from threading import local
-except ImportError:
-    from sqlalchemy.util import ThreadLocal as local
-
-from sqlalchemy import sql
-from sqlalchemy.engine import create_engine, Engine
-
-__all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine']
-
-class BaseProxyEngine(sql.Executor):
-    """Basis for all proxy engines."""
-
-    def get_engine(self):
-        raise NotImplementedError
-
-    def set_engine(self, engine):
-        raise NotImplementedError
-
-    engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e))
-
-    def execute_compiled(self, *args, **kwargs):
-        """Override superclass behaviour.
-
-        This method is required to be present as it overrides the
-        `execute_compiled` present in ``sql.Engine``.
-        """
-
-        return self.get_engine().execute_compiled(*args, **kwargs)
-
-    def compiler(self, *args, **kwargs):
-        """Override superclass behaviour.
-
-        This method is required to be present as it overrides the
-        `compiler` method present in ``sql.Engine``.
-        """
-
-        return self.get_engine().compiler(*args, **kwargs)
-
-    def __getattr__(self, attr):
-        """Provide proxying for methods that are not otherwise present on this ``BaseProxyEngine``.
-
-        Note that methods which are present on the base class
-        ``sql.Engine`` will **not** be proxied through this, and must
-        be explicit on this class.
-        """
-
-        # call get_engine() to give subclasses a chance to change
-        # connection establishment behavior
-        e = self.get_engine()
-        if e is not None:
-            return getattr(e, attr)
-        raise AttributeError("No connection established in ProxyEngine: "
-                             " no access to %s" % attr)
-
-class AutoConnectEngine(BaseProxyEngine):
-    """An SQLEngine proxy that automatically connects when necessary."""
-
-    def __init__(self, dburi, **kwargs):
-        BaseProxyEngine.__init__(self)
-        self.dburi = dburi
-        self.kwargs = kwargs
-        self._engine = None
-
-    def get_engine(self):
-        if self._engine is None:
-            if callable(self.dburi):
-                dburi = self.dburi()
-            else:
-                dburi = self.dburi
-            self._engine = create_engine(dburi, **self.kwargs)
-        return self._engine
-
-
-class ProxyEngine(BaseProxyEngine):
-    """Engine proxy for lazy and late initialization.
-
-    This engine will delegate access to a real engine set with connect().
-    """
-
-    def __init__(self, **kwargs):
-        BaseProxyEngine.__init__(self)
-        # create the local storage for uri->engine map and current engine
-        self.storage = local()
-        self.kwargs = kwargs
-
-    def connect(self, *args, **kwargs):
-        """Establish connection to a real engine."""
-
-        kwargs.update(self.kwargs)
-        if not kwargs:
-            key = repr(args)
-        else:
-            key = "%s, %s" % (repr(args), repr(sorted(kwargs.items())))
-        try:
-            map = self.storage.connection
-        except AttributeError:
-            self.storage.connection = {}
-            self.storage.engine = None
-            map = self.storage.connection
-        try:
-            self.storage.engine = map[key]
-        except KeyError:
-            map[key] = create_engine(*args, **kwargs)
-            self.storage.engine = map[key]
-
-    def get_engine(self):
-        if not hasattr(self.storage, 'engine') or self.storage.engine is None:
-            raise AttributeError("No connection established")
-        return self.storage.engine
-
-    def set_engine(self, engine):
-        self.storage.engine = engine
index 68538f3cb4e6065f6bce44f3fc5b98ff55ca3fa5..1920b6f924a7b0f0d5b054caf9e57c943dd562e5 100644 (file)
+"""SelectResults has been rolled into Query.  This class is now just a placeholder."""
+
 import sqlalchemy.sql as sql
 import sqlalchemy.orm as orm
 
 class SelectResultsExt(orm.MapperExtension):
     """a MapperExtension that provides SelectResults functionality for the
     results of query.select_by() and query.select()"""
+    
     def select_by(self, query, *args, **params):
-        return SelectResults(query, query.join_by(*args, **params))
+        q = query
+        for a in args:
+            q = q.filter(a)
+        return q.filter_by(**params)
+        
     def select(self, query, arg=None, **kwargs):
         if isinstance(arg, sql.FromClause) and arg.supports_execution():
             return orm.EXT_PASS
         else:
-            return SelectResults(query, arg, ops=kwargs)
-
-class SelectResults(object):
-    """Build a query one component at a time via separate method
-    calls, each call transforming the previous ``SelectResults``
-    instance into a new ``SelectResults`` instance with further
-    limiting criterion added. When interpreted in an iterator context
-    (such as via calling ``list(selectresults)``), executes the query.
-    """
-
-    def __init__(self, query, clause=None, ops={}, joinpoint=None):
-        """Construct a new ``SelectResults`` using the given ``Query``
-        object and optional ``WHERE`` clause.  `ops` is an optional
-        dictionary of bind parameter values.
-        """
-
-        self._query = query
-        self._clause = clause
-        self._ops = {}
-        self._ops.update(ops)
-        self._joinpoint = joinpoint or (self._query.table, self._query.mapper)
-
-    def options(self,*args, **kwargs):
-        """Apply mapper options to the underlying query.
-
-        See also ``Query.options``.
-        """
-
-        new = self.clone()
-        new._query = new._query.options(*args, **kwargs)
-        return new
-
-    def count(self):
-        """Execute the SQL ``count()`` function against the ``SelectResults`` criterion."""
-
-        return self._query.count(self._clause, **self._ops)
-
-    def _col_aggregate(self, col, func):
-        """Execute ``func()`` function against the given column.
-
-        For performance, only use subselect if `order_by` attribute is set.
-        """
-
-        if self._ops.get('order_by'):
-            s1 = sql.select([col], self._clause, **self._ops).alias('u')
-            return sql.select([func(s1.corresponding_column(col))]).scalar()
-        else:
-            return sql.select([func(col)], self._clause, **self._ops).scalar()
-
-    def min(self, col):
-        """Execute the SQL ``min()`` function against the given column."""
-
-        return self._col_aggregate(col, sql.func.min)
-
-    def max(self, col):
-        """Execute the SQL ``max()`` function against the given column."""
-
-        return self._col_aggregate(col, sql.func.max)
-
-    def sum(self, col):
-        """Execute the SQL ``sum()`` function against the given column."""
-
-        return self._col_aggregate(col, sql.func.sum)
-
-    def avg(self, col):
-        """Execute the SQL ``avg()`` function against the given column."""
-
-        return self._col_aggregate(col, sql.func.avg)
-
-    def clone(self):
-        """Create a copy of this ``SelectResults``."""
-
-        return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint)
-
-    def filter(self, clause):
-        """Apply an additional ``WHERE`` clause against the query."""
-
-        new = self.clone()
-        new._clause = sql.and_(self._clause, clause)
-        return new
-
-    def select(self, clause):
-        return self.filter(clause)
-
-    def select_by(self, *args, **kwargs):
-        return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1]))
-
-    def order_by(self, order_by):
-        """Apply an ``ORDER BY`` to the query."""
-
-        new = self.clone()
-        new._ops['order_by'] = order_by
-        return new
-
-    def limit(self, limit):
-        """Apply a ``LIMIT`` to the query."""
-
-        return self[:limit]
-
-    def offset(self, offset):
-        """Apply an ``OFFSET`` to the query."""
-
-        return self[offset:]
-
-    def distinct(self):
-        """Apply a ``DISTINCT`` to the query."""
-
-        new = self.clone()
-        new._ops['distinct'] = True
-        return new
-
-    def list(self):
-        """Return the results represented by this ``SelectResults`` as a list.
-
-        This results in an execution of the underlying query.
-        """
-
-        return list(self)
-
-    def select_from(self, from_obj):
-        """Set the `from_obj` parameter of the query.
-
-        `from_obj` is a list of one or more tables.
-        """
-
-        new = self.clone()
-        new._ops['from_obj'] = from_obj
-        return new
-
-    def join_to(self, prop):
-        """Join the table of this ``SelectResults`` to the table located against the given property name.
-
-        Subsequent calls to join_to or outerjoin_to will join against
-        the rightmost table located from the previous `join_to` or
-        `outerjoin_to` call, searching for the property starting with
-        the rightmost mapper last located.
-        """
-
-        new = self.clone()
-        (clause, mapper) = self._join_to(prop, outerjoin=False)
-        new._ops['from_obj'] = [clause]
-        new._joinpoint = (clause, mapper)
-        return new
-
-    def outerjoin_to(self, prop):
-        """Outer join the table of this ``SelectResults`` to the 
-        table located against the given property name.
-
-        Subsequent calls to join_to or outerjoin_to will join against
-        the rightmost table located from the previous ``join_to`` or
-        ``outerjoin_to`` call, searching for the property starting with
-        the rightmost mapper last located.
-        """
-
-        new = self.clone()
-        (clause, mapper) = self._join_to(prop, outerjoin=True)
-        new._ops['from_obj'] = [clause]
-        new._joinpoint = (clause, mapper)
-        return new
-
-    def _join_to(self, prop, outerjoin=False):
-        [keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1])
-        clause = self._joinpoint[0]
-        mapper = self._joinpoint[1]
-        for key in keys:
-            prop = mapper.props[key]
-            if outerjoin:
-                clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
-            else:
-                clause = clause.join(prop.select_table, prop.get_join(mapper))
-            mapper = prop.mapper
-        return (clause, mapper)
-
-    def compile(self):
-        return self._query.compile(self._clause, **self._ops)
-
-    def __getitem__(self, item):
-        if isinstance(item, slice):
-            start = item.start
-            stop = item.stop
-            if (isinstance(start, int) and start < 0) or \
-               (isinstance(stop, int) and stop < 0):
-                return list(self)[item]
-            else:
-                res = self.clone()
-                if start is not None and stop is not None:
-                    res._ops.update(dict(offset=self._ops.get('offset', 0)+start, limit=stop-start))
-                elif start is None and stop is not None:
-                    res._ops.update(dict(limit=stop))
-                elif start is not None and stop is None:
-                    res._ops.update(dict(offset=self._ops.get('offset', 0)+start))
-                if item.step is not None:
-                    return list(res)[None:None:item.step]
-                else:
-                    return res
-        else:
-            return list(self[item:item+1])[0]
-
-    def __iter__(self):
-        return iter(self._query.select_whereclause(self._clause, **self._ops))
+            if arg is not None:
+                query = query.filter(arg)
+            return query._legacy_select_kwargs(**kwargs)
+
+def SelectResults(query, clause=None, ops={}):
+    if clause is not None:
+        query = query.filter(clause)
+    query = query.options(orm.extension(SelectResultsExt()))
+    return query._legacy_select_kwargs(**ops)
index 21c1fac51b7ac48d1de4a84a6b18b0facbeb7d98..31b5947846779a74f20d5a04f0c5f314ae91fab1 100644 (file)
@@ -294,6 +294,7 @@ Boring tests here.  Nothing of real expository value.
 """
 
 from sqlalchemy import *
+from sqlalchemy.orm import *
 from sqlalchemy.ext.sessioncontext import SessionContext
 from sqlalchemy.ext.assignmapper import assign_mapper
 from sqlalchemy.exceptions import *
index f3243a37249a5391b752482abacd76f3c2f28021..5b96bc1148df397e7f68be14a2f9e77df2ecfac0 100644 (file)
@@ -14,7 +14,6 @@ def suite():
         # schema/tables
         'engine.reflection', 
 
-       'engine.proxy_engine'
         )
     alltests = unittest.TestSuite()
     for name in modules_to_test:
diff --git a/test/engine/proxy_engine.py b/test/engine/proxy_engine.py
deleted file mode 100644 (file)
index 26b738e..0000000
+++ /dev/null
@@ -1,204 +0,0 @@
-from testbase import PersistTest
-import testbase
-import os
-
-from sqlalchemy import *
-from sqlalchemy.ext.proxy import ProxyEngine
-
-
-#
-# Define an engine, table and mapper at the module level, to show that the
-# table and mapper can be used with different real engines in multiple threads
-#
-
-
-class ProxyTestBase(PersistTest):
-    def setUpAll(self):
-
-        global users, User, module_engine, module_metadata
-
-        module_engine = ProxyEngine(echo=testbase.echo)
-        module_metadata = MetaData()
-
-        users = Table('users', module_metadata, 
-                      Column('user_id', Integer, primary_key=True),
-                      Column('user_name', String(16)),
-                      Column('password', String(20))
-                      )
-
-        class User(object):
-            pass
-
-        User.mapper = mapper(User, users)
-    def tearDownAll(self):
-        clear_mappers()
-
-class ConstructTest(ProxyTestBase):
-    """tests that we can build SQL constructs without engine-specific parameters, particulary
-    oid_column, being needed, as the proxy engine is usually not connected yet."""
-
-    def test_join(self):
-        engine = ProxyEngine()
-        t = Table('table1', engine, 
-            Column('col1', Integer, primary_key=True))
-        t2 = Table('table2', engine, 
-            Column('col2', Integer, ForeignKey('table1.col1')))
-        j = join(t, t2)
-        
-
-class ProxyEngineTest1(ProxyTestBase):
-
-    def test_engine_connect(self):
-        # connect to a real engine
-        module_engine.connect(testbase.db_uri)
-        module_metadata.create_all(module_engine)
-
-        session = create_session(bind_to=module_engine)
-        try:
-
-            user = User()
-            user.user_name='fred'
-            user.password='*'
-
-            session.save(user)
-            session.flush()
-
-            query = session.query(User)
-
-            # select
-            sqluser = query.select_by(user_name='fred')[0]
-            assert sqluser.user_name == 'fred'
-
-            # modify
-            sqluser.user_name = 'fred jones'
-
-            # flush - saves everything that changed
-            session.flush()
-        
-            allusers = [ user.user_name for user in query.select() ]
-            assert allusers == ['fred jones']
-
-        finally:
-            module_metadata.drop_all(module_engine)
-
-
-class ThreadProxyTest(ProxyTestBase):
-
-    def tearDownAll(self):
-        try:
-            os.remove('threadtesta.db')
-        except OSError:
-            pass
-        try:
-            os.remove('threadtestb.db')
-        except OSError:
-            pass
-        
-    @testbase.supported('sqlite')
-    def test_multi_thread(self):
-        
-        from threading import Thread
-        from Queue import Queue
-        
-        # start 2 threads with different connection params
-        # and perform simultaneous operations, showing that the
-        # 2 threads don't share a connection
-        qa = Queue()
-        qb = Queue()
-        def run(db_uri, uname, queue):
-            def test():
-                
-                try:
-                    module_engine.connect(db_uri)
-                    module_metadata.create_all(module_engine)
-                    try:
-                        session = create_session(bind_to=module_engine)
-
-                        query = session.query(User)
-
-                        all = list(query.select())
-                        assert all == []
-
-                        u = User()
-                        u.user_name = uname
-                        u.password = 'whatever'
-
-                        session.save(u)
-                        session.flush()
-
-                        names = [u.user_name for u in query.select()]
-                        assert names == [uname]
-                    finally:
-                        module_metadata.drop_all(module_engine)
-                        module_engine.get_engine().dispose()
-                except Exception, e:
-                    import traceback
-                    traceback.print_exc()
-                    queue.put(e)
-                else:
-                    queue.put(False)
-            return test
-
-        a = Thread(target=run('sqlite:///threadtesta.db', 'jim', qa))
-        b = Thread(target=run('sqlite:///threadtestb.db', 'joe', qb))
-        
-        a.start()
-        b.start()
-        
-        # block and wait for the threads to push their results
-        res = qa.get()
-        if res != False:
-            raise res
-
-        res = qb.get()
-        if res != False:
-            raise res
-
-
-class ProxyEngineTest2(ProxyTestBase):
-
-    def test_table_singleton_a(self):
-        """set up for table singleton check
-        """
-        #
-        # For this 'test', create a proxy engine instance, connect it
-        # to a real engine, and make it do some work
-        #
-        engine = ProxyEngine()
-        cats = Table('cats', engine,
-                     Column('cat_id', Integer, primary_key=True),
-                     Column('cat_name', String))
-
-        engine.connect(testbase.db_uri)
-
-        cats.create(engine)
-        cats.drop(engine)
-
-        ProxyEngineTest2.cats_table_a = cats
-        assert isinstance(cats, Table)
-
-    def test_table_singleton_b(self):
-        """check that a table on a 2nd proxy engine instance gets 2nd table
-        instance
-        """
-        #
-        # Now create a new proxy engine instance and attach the same
-        # table as the first test. This should result in 2 table instances,
-        # since different proxy engine instances can't attach to the
-        # same table instance
-        #
-        engine = ProxyEngine()
-        cats = Table('cats', engine,
-                     Column('cat_id', Integer, primary_key=True),
-                     Column('cat_name', String))
-        assert id(cats) != id(ProxyEngineTest2.cats_table_a)
-
-        # the real test -- if we're still using the old engine reference,
-        # this will fail because the old reference's local storage will
-        # not have the default attributes
-        engine.connect(testbase.db_uri)
-        cats.create(engine)
-        cats.drop(engine)
-
-if __name__ == "__main__":
-    testbase.main()
index c89bf4b145eee09e41e69286b8eb883feeda5013..f80352f63efd1968dd48cf4119370f016b139a54 100644 (file)
@@ -4,6 +4,7 @@ import unittest, sys, datetime
 import tables
 db = testbase.db
 from sqlalchemy import *
+from sqlalchemy.orm import *
 
 
 class TransactionTest(testbase.PersistTest):
index 25e20168c9d5aca7753a636c2feac7dac62b4808..78aad0c93476b9f55911986b609563ac5bc68f89 100644 (file)
@@ -1,7 +1,8 @@
 import testbase
 from sqlalchemy.ext.activemapper           import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
-from sqlalchemy             import and_, or_, clear_mappers, backref, create_session, exceptions
+from sqlalchemy             import and_, or_, exceptions
 from sqlalchemy             import ForeignKey, String, Integer, DateTime, Table, Column
+from sqlalchemy.orm         import clear_mappers, backref, create_session, class_mapper
 from datetime               import datetime
 import sqlalchemy
 
@@ -10,7 +11,7 @@ import sqlalchemy.ext.activemapper as activemapper
 
 class testcase(testbase.PersistTest):
     def setUpAll(self):
-        sqlalchemy.clear_mappers()
+        clear_mappers()
         objectstore.clear()
         global Person, Preferences, Address
         
@@ -262,7 +263,7 @@ class testcase(testbase.PersistTest):
 
 class testmanytomany(testbase.PersistTest):
      def setUpAll(self):
-         sqlalchemy.clear_mappers()
+         clear_mappers()
          objectstore.clear()
          global secondarytable, foo, baz
          secondarytable = Table("secondarytable",
@@ -312,14 +313,14 @@ class testmanytomany(testbase.PersistTest):
 
          # Optimistically based on activemapper one_to_many test, try  to append
          # baz1 to foo1.bazrel - (AttributeError: 'foo' object has no attribute 'bazrel')
-         print sqlalchemy.class_mapper(foo).props
-         print sqlalchemy.class_mapper(baz).props
+         print class_mapper(foo).props
+         print class_mapper(baz).props
          foo1.bazrel.append(baz1)
          assert (foo1.bazrel == [baz1])
         
 class testselfreferential(testbase.PersistTest):
     def setUpAll(self):
-        sqlalchemy.clear_mappers()
+        clear_mappers()
         objectstore.clear()
         global TreeNode
         class TreeNode(activemapper.ActiveMapper):
index d42a809c204077907bcb41401c46a07c724f3a42..479e9f399e7a8278cf8043029f3469ee2ba903ed 100644 (file)
@@ -2,6 +2,7 @@ from testbase import PersistTest, AssertMixin
 import testbase
 
 from sqlalchemy import *
+from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
 
 from sqlalchemy.ext.assignmapper import assign_mapper
 from sqlalchemy.ext.sessioncontext import SessionContext
@@ -30,10 +31,7 @@ class OverrideAttributesTest(PersistTest):
         
         ctx = SessionContext(create_session)
         assign_mapper(ctx, SomeObject, table, properties={
-            # this is the current workaround for class attribute name/collection collision: specify collection_class
-            # explicitly.   when we do away with class attributes specifying collection classes, this wont be
-            # needed anymore.
-            'options':relation(SomeOtherObject, collection_class=list)
+            'options':relation(SomeOtherObject)
         })
         assign_mapper(ctx, SomeOtherObject, table2)
         class_mapper(SomeObject)
index b6476c836137eb3c7a9668506f97305574d336a3..57efe89e39d33edd896715eed137574e6b75a4bc 100644 (file)
@@ -3,6 +3,7 @@ import sqlalchemy.util as util
 import unittest
 import testbase
 from sqlalchemy import *
+from sqlalchemy.orm import *
 from sqlalchemy.ext.associationproxy import *
 
 db = testbase.db
index 73d0405a4a5509c049e1ec9b0cb7c17b53046305..c2811d5ce06e7b2426239ce5d966a8b1c71de467 100644 (file)
@@ -3,6 +3,7 @@ import sqlalchemy.util as util
 import unittest, sys, os
 import testbase
 from sqlalchemy import *
+from sqlalchemy.orm import *
 from sqlalchemy.ext.orderinglist import *
 
 db = testbase.db
index 5af61b6a70647f7e6c28155475099b60d4e5cfd7..eeaff7d549b8ec8b401379f670fd7e6d286c9ad9 100644 (file)
@@ -3,6 +3,7 @@ import testbase
 import tables
 
 from sqlalchemy import *
+from sqlalchemy.orm import *
 
 from sqlalchemy.ext.selectresults import SelectResultsExt, SelectResults
 
@@ -157,7 +158,7 @@ class RelationsTest(AssertMixin):
         })
         session = create_session()
         query = SelectResults(session.query(tables.User))
-        x = query.join_to('orders').join_to('items').select(tables.Item.c.item_id==2)
+        x = query.join(['orders','items']).select(tables.Item.c.item_id==2)
         print x.compile()
         self.assert_result(list(x), tables.User, tables.user_result[2])
     def test_outerjointo(self):
@@ -169,7 +170,7 @@ class RelationsTest(AssertMixin):
         })
         session = create_session()
         query = SelectResults(session.query(tables.User))
-        x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        x = query.outerjoin(['orders', 'items']).select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
         print x.compile()
         self.assert_result(list(x), tables.User, *tables.user_result[1:3])
     def test_outerjointo_count(self):
@@ -181,7 +182,7 @@ class RelationsTest(AssertMixin):
         })
         session = create_session()
         query = SelectResults(session.query(tables.User))
-        x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+        x = query.outerjoin(['orders', 'items']).select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
         assert x==2
     def test_from(self):
         mapper(tables.User, tables.users, properties={
index 0a55a1b56a43957129b8055a6af92d6c4b2b8a2b..6453d7ddf751cff845f5d94a863d8647f6141e85 100644 (file)
@@ -122,10 +122,10 @@ class EagerTest(AssertMixin):
     def test_dslish(self):
         """test the same as witheagerload except building the query via SelectResults"""
         s = create_session()
-        q=SelectResults(s.query(Test).options(eagerload('category')))
-        l=q.select ( 
+        q=s.query(Test).options(eagerload('category'))
+        l=q.filter ( 
             and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False))
-            ).outerjoin_to('owner_option')
+            ).outerjoin('owner_option')
             
         result = ["%d %s" % ( t.id,t.category.name ) for t in l]
         print result
@@ -316,15 +316,9 @@ class EagerTest4(testbase.ORMTest):
         sess.flush()
 
         q = sess.query(Department)
-        filters = [q.join_to('employees'),
-                   Employee.c.name.startswith('J')]
-
-        d = SelectResults(q)
-        d = d.join_to('employees').filter(Employee.c.name.startswith('J'))
-        d = d.distinct()
-        d = d.order_by([desc(Department.c.name)])
-        assert d.count() == 2
-        assert d[0] is d2
+        q = q.join('employees').filter(Employee.c.name.startswith('J')).distinct().order_by([desc(Department.c.name)])
+        assert q.count() == 2
+        assert q[0] is d2
 
 class EagerTest5(testbase.ORMTest):
     """test the construction of AliasedClauses for the same eager load property but different 
index fbcdb5131a02daac2e53fa37075b07dc17281dad..ac47ff33cee1df9081bc37fd27d638b668c45a4c 100644 (file)
@@ -118,6 +118,10 @@ class RelationTest2(testbase.ORMTest):
         self.do_test("join1", True)
     def testrelationonsubclass_j2_data(self):
         self.do_test("join2", True)
+    def testrelationonsubclass_j3_nodata(self):
+        self.do_test("join3", False)
+    def testrelationonsubclass_j3_data(self):
+        self.do_test("join3", True)
                 
     def do_test(self, jointype="join1", usedata=False):
         class Person(AttrSettable):
@@ -130,19 +134,24 @@ class RelationTest2(testbase.ORMTest):
                 'person':people.select(people.c.type=='person'),
                 'manager':join(people, managers, people.c.person_id==managers.c.person_id)
             }, None)
+            polymorphic_on=poly_union.c.type
         elif jointype == "join2":
             poly_union = polymorphic_union({
                 'person':people.select(people.c.type=='person'),
                 'manager':managers.join(people, people.c.person_id==managers.c.person_id)
             }, None)
-
+            polymorphic_on=poly_union.c.type
+        elif jointype == "join3":
+            poly_union = None
+            polymorphic_on = people.c.type
+            
         if usedata:
             class Data(object):
                 def __init__(self, data):
                     self.data = data
             mapper(Data, data)
             
-        mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=poly_union.c.type)
+        mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=polymorphic_on)
 
         if usedata:
             mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager',
@@ -204,6 +213,10 @@ class RelationTest3(testbase.ORMTest):
        self.do_test("join1", True)
     def testrelationonbaseclass_j2_data(self):
        self.do_test("join2", True)
+    def testrelationonbaseclass_j3_nodata(self):
+       self.do_test("join3", False)
+    def testrelationonbaseclass_j3_data(self):
+       self.do_test("join3", True)
 
     def do_test(self, jointype="join1", usedata=False):
         class Person(AttrSettable):
@@ -226,6 +239,8 @@ class RelationTest3(testbase.ORMTest):
                 'manager':join(people, managers, people.c.person_id==managers.c.person_id),
                 'person':people.select(people.c.type=='person')
             }, None)
+        elif jointype == "join3":
+            poly_union=None
             
         if usedata:
             mapper(Data, data)
index e754945bb8f77ebd253c5e98536e30e62f7c3562..43125982670c89b61485c48db113cc25348b69df 100644 (file)
@@ -1223,18 +1223,6 @@ class EagerTest(MapperSuperTest):
         l = session.query(User).select()
         self.assert_result(l, User, *user_address_result)
         
-    def testcompile(self):
-        """tests deferred operation of a pre-compiled mapper statement"""
-        session = create_session()
-        m = mapper(User, users, properties = dict(
-            addresses = relation(mapper(Address, addresses), lazy = False)
-        ))
-        s = session.query(m).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)).compile()
-        c = s.compile()
-        self.echo("\n" + str(c) + repr(c.get_params()))
-        
-        l = m.instances(s.execute(emailad = 'jack@bean.com'), session)
-        self.echo(repr(l))
     
     def testonselect(self):
         """test eager loading of a mapper which is against a select"""
@@ -1414,179 +1402,6 @@ class EagerTest(MapperSuperTest):
                ])},
         )
 
-class InstancesTest(MapperSuperTest):
-    def testcustomfromalias(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-        query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.user_id', addresses.c.address_id])
-        q = create_session().query(User)
-        
-        def go():
-            l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-        
-    def testcustomeagerquery(self):
-        mapper(User, users, properties={
-            # setting lazy=True - the contains_eager() option below
-            # should imply eagerload()
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-        
-        selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id])
-        q = create_session().query(User)
-        
-        def go():
-            l = q.options(contains_eager('addresses')).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def testcustomeagerwithstringalias(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
-        })
-        mapper(Address, addresses)
-
-        adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
-        q = create_session().query(User)
-
-        def go():
-            l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def testcustomeagerwithalias(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
-        })
-        mapper(Address, addresses)
-
-        adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
-        q = create_session().query(User)
-
-        def go():
-            l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def testcustomeagerwithdecorator(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
-        })
-        mapper(Address, addresses)
-
-        adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id])
-        def decorate(row):
-            d = {}
-            for c in addresses.columns:
-                d[c] = row[adalias.corresponding_column(c)]
-            return d
-            
-        q = create_session().query(User)
-
-        def go():
-            l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-    
-    def testmultiplemappers(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-
-        sess = create_session()
-        
-        (user7, user8, user9) = sess.query(User).select()
-        (address1, address2, address3, address4) = sess.query(Address).select()
-        
-        selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id])
-        q = sess.query(User)
-        l = q.instances(selectquery.execute(), Address)
-        # note the result is a cartesian product
-        assert l == [
-            (user7, address1),
-            (user8, address2),
-            (user8, address3),
-            (user8, address4),
-            (user9, None)
-        ]
-    
-    def testmultipleonquery(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-        sess = create_session()
-        (user7, user8, user9) = sess.query(User).select()
-        (address1, address2, address3, address4) = sess.query(Address).select()
-        q = sess.query(User)
-        q = q.add_entity(Address).outerjoin('addresses')
-        l = q.list()
-        assert l == [
-            (user7, address1),
-            (user8, address2),
-            (user8, address3),
-            (user8, address4),
-            (user9, None)
-        ]
-
-    def testcolumnonquery(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-        
-        sess = create_session()
-        (user7, user8, user9) = sess.query(User).select()
-        q = sess.query(User)
-        q = q.group_by([c for c in users.c]).order_by(User.c.user_id).outerjoin('addresses').add_column(func.count(addresses.c.address_id).label('count'))
-        l = q.list()
-        assert l == [
-            (user7, 1),
-            (user8, 3),
-            (user9, 0)
-        ], repr(l)
-        
-    def testmapperspluscolumn(self):
-        mapper(User, users)
-        s = select([users, func.count(addresses.c.address_id).label('count')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.user_id])
-        sess = create_session()
-        (user7, user8, user9) = sess.query(User).select()
-        q = sess.query(User)
-        l = q.instances(s.execute(), "count")
-        assert l == [
-            (user7, 1),
-            (user8, 3),
-            (user9, 0)
-        ]
-        
-    def testmappersplustwocolumns(self):
-        mapper(User, users)
-
-        # Fixme ticket #475!
-        if db.engine.name == 'mysql':
-            col2 = func.concat("Name:", users.c.user_name).label('concat')
-        else:
-            col2 = ("Name:" + users.c.user_name).label('concat')
-        
-        s = select([users, func.count(addresses.c.address_id).label('count'), col2], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.user_id])
-        sess = create_session()
-        (user7, user8, user9) = sess.query(User).select()
-        q = sess.query(User)
-        l = q.instances(s.execute(), "count", "concat")
-        print l
-        assert l == [
-            (user7, 1, "Name:jack"),
-            (user8, 3, "Name:ed"),
-            (user9, 0, "Name:fred")
-        ]
 
 
 if __name__ == "__main__":    
index fbee5e88c86646e06092fa0cc18d0c7be2be802b..b7b906466d29211bea989a102e143733f7f2e29c 100644 (file)
@@ -2,6 +2,40 @@ from sqlalchemy import *
 from sqlalchemy.orm import *
 import testbase
 
+class Base(object):
+    def __init__(self, **kwargs):
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+            
+    def __ne__(self, other):
+        return not self.__eq__(other)
+        
+    def __eq__(self, other):
+        """'passively' compare this object to another.
+        
+        only look at attributes that are present on the source object.
+        
+        """
+        # use __dict__ to avoid instrumented properties
+        for attr in self.__dict__.keys():
+            if attr[0] == '_':
+                continue
+            value = getattr(self, attr)
+            if hasattr(value, '__iter__') and not isinstance(value, basestring):
+                if len(value) == 0:
+                    continue
+                for (us, them) in zip(value, getattr(other, attr)):
+                    if us != them:
+                        return False
+                else:
+                    continue
+            else:
+                if value is not None:
+                    if value != getattr(other, attr):
+                        return False
+        else:
+            return True
+
 class QueryTest(testbase.ORMTest):
     keep_mappers = True
     keep_data = True
@@ -123,29 +157,11 @@ class QueryTest(testbase.ORMTest):
             dict(keyword_id=7, item_id=2),
             dict(keyword_id=6, item_id=3)
         )
-        
+
+    
     def setup_mappers(self):
-        global User, Order, Item, Keyword, Address, Base
+        global User, Order, Item, Keyword, Address
         
-        class Base(object):
-            def __init__(self, **kwargs):
-                for k in kwargs:
-                    setattr(self, k, kwargs[k])
-            def __eq__(self, other):
-                for attr in dir(self):
-                    if attr[0] == '_':
-                        continue
-                    value = getattr(self, attr)
-                    if isinstance(value, list):
-                        for (us, them) in zip(value, getattr(other, attr)):
-                            if us != them:
-                                return False
-                        else:
-                            return True
-                    else:
-                        if value is not None:
-                            return value == getattr(other, attr)
-                    
         class User(Base):pass
         class Order(Base):pass
         class Item(Base):pass
@@ -153,6 +169,7 @@ class QueryTest(testbase.ORMTest):
         class Address(Base):pass
 
         mapper(User, users, properties={
+            'addresses':relation(Address),
             'orders':relation(Order, backref='user'), # o2m, m2o
         })
         mapper(Address, addresses)
@@ -165,6 +182,23 @@ class QueryTest(testbase.ORMTest):
         })
         mapper(Keyword, keywords)
 
+    @property
+    def user_address_result(self):
+        return [
+            User(id=7, addresses=[
+                Address(id=1)
+            ]), 
+            User(id=8, addresses=[
+                Address(id=2),
+                Address(id=3),
+                Address(id=4)
+            ]), 
+            User(id=9, addresses=[
+                Address(id=5)
+            ]), 
+            User(id=10, addresses=[])
+        ]
+
 class GetTest(QueryTest):
     def test_get(self):
         s = create_session()
@@ -190,18 +224,26 @@ class GetTest(QueryTest):
         mapper(LocalFoo, table)
         assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring)
 
+class CompileTest(QueryTest):
+    def test_deferred(self):
+        session = create_session()
+        s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.id)).compile()
+        
+        l = session.query(User).instances(s.execute(emailad = 'jack@bean.com'))
+        assert [User(id=7)] == l
+    
 class SliceTest(QueryTest):
     def test_first(self):
-        assert create_session().query(User).first() == User(id=7)
+        assert  User(id=7) == create_session().query(User).first()
         
         assert create_session().query(User).filter(users.c.id==27).first() is None
         
 class FilterTest(QueryTest):
     def test_basic(self):
-        assert create_session().query(User).all() == [User(id=7), User(id=8), User(id=9),User(id=10)]
+        assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all()
 
     def test_onefilter(self):
-        assert create_session().query(User).filter(users.c.name.endswith('ed')).all() == [User(id=8), User(id=9)]
+        assert [User(id=8), User(id=9)] == create_session().query(User).filter(users.c.name.endswith('ed')).all()
 
 class ParentTest(QueryTest):
     def test_o2m(self):
@@ -259,6 +301,114 @@ class JoinTest(QueryTest):
         result = create_session().query(User).select_from(users.join(oalias)).filter(oalias.c.description.in_("order 1", "order 2", "order 3")).join(['orders', 'items']).filter_by(id=4).all()
         assert [User(id=7, name='jack')] == result
 
+class InstancesTest(QueryTest):
+
+    def test_from_alias(self):
+
+        query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.id', addresses.c.id])
+        q = create_session().query(User)
+
+        def go():
+            l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())
+            assert self.user_address_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_contains_eager(self):
+
+        selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
+        q = create_session().query(User)
+
+        def go():
+            l = q.options(contains_eager('addresses')).instances(selectquery.execute())
+            assert self.user_address_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_contains_eager_alias(self):
+        adalias = addresses.alias('adalias')
+        selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id])
+        q = create_session().query(User)
+
+        def go():
+            # test using a string alias name
+            l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
+            assert self.user_address_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+        def go():
+            # test using the Alias object itself
+            l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
+            assert self.user_address_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+        def decorate(row):
+            d = {}
+            for c in addresses.columns:
+                d[c] = row[adalias.corresponding_column(c)]
+            return d
+
+        def go():
+            # test using a custom 'decorate' function
+            l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
+            assert self.user_address_result == l
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_multi_mappers(self):
+        sess = create_session()
+
+        (user7, user8, user9, user10) = sess.query(User).all()
+        (address1, address2, address3, address4, address5) = sess.query(Address).all()
+
+        # note the result is a cartesian product
+        expected = [(user7, address1),
+            (user8, address2),
+            (user8, address3),
+            (user8, address4),
+            (user9, address5),
+            (user10, None)]
+        
+        selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
+        q = sess.query(User)
+        l = q.instances(selectquery.execute(), Address)
+        assert l == expected
+
+        q = sess.query(User)
+        q = q.add_entity(Address).outerjoin('addresses')
+        l = q.all()
+        assert l == expected
+
+    def test_multi_columns(self):
+        sess = create_session()
+        (user7, user8, user9, user10) = sess.query(User).select()
+        expected = [(user7, 1),
+            (user8, 3),
+            (user9, 1),
+            (user10, 0)
+            ]
+            
+        q = sess.query(User)
+        q = q.group_by([c for c in users.c]).order_by(User.c.id).outerjoin('addresses').add_column(func.count(addresses.c.id).label('count'))
+        l = q.all()
+        assert l == expected
+
+        s = select([users, func.count(addresses.c.id).label('count')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.id])
+        q = sess.query(User)
+        l = q.instances(s.execute(), "count")
+        assert l == expected
+
+    @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475)
+    def test_two_columns(self):
+        sess = create_session()
+        (user7, user8, user9, user10) = sess.query(User).select()
+        expected = [
+            (user7, 1, "Name:jack"),
+            (user8, 3, "Name:ed"),
+            (user9, 1, "Name:fred"),
+            (user10, 0, "Name:chuck")]
+
+        s = select([users, func.count(addresses.c.id).label('count'), ("Name:" + users.c.name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.id])
+        q = create_session().query(User)
+        l = q.instances(s.execute(), "count", "concat")
+        assert l == expected
 
 
 
index 0bc1a6b2e10d8bbad74a240cfe0a8bd3621e85eb..992a3afc03e23402dfaf4f212e9b18fc12dc06fb 100644 (file)
@@ -4,6 +4,7 @@ import unittest, sys, os
 import sqlalchemy.schema as schema
 import testbase
 from sqlalchemy import *
+from sqlalchemy.orm import mapper, create_session
 import sqlalchemy
 
 db = testbase.db
@@ -159,6 +160,9 @@ class AutoIncrementTest(PersistTest):
             table.drop()    
 
     def testfetchid(self):
+        
+        # TODO: what does this test do that all the various ORM tests dont ?
+        
         meta = BoundMetaData(testbase.db)
         table = Table("aitest", meta, 
             Column('id', Integer, primary_key=True),
index 65a7cce0d095df386d9a9687a10f2a4e4b45bcb1..15f4dd14c3eeaf2d286afc0937962a2ceae8537b 100644 (file)
@@ -2,6 +2,7 @@
 import testbase
 
 from sqlalchemy import *
+from sqlalchemy.orm import mapper, relation, create_session, eagerload
 
 """verrrrry basic unicode column name testing"""
 
@@ -76,4 +77,4 @@ class UnicodeSchemaTest(testbase.PersistTest):
         assert new_a1.t2s[0].a == b1.a
         
 if __name__ == '__main__':
-    testbase.main()
\ No newline at end of file
+    testbase.main()
index fdbe6aa5b4a5a95db13c4f503f9a03b061c5c64c..51c5c6d7f191ba0ab29088379922d267b28a234b 100644 (file)
@@ -1,11 +1,13 @@
+"""base import for all test cases.  Patches in enhancements to unittest.TestCase, 
+instruments SQLAlchemy dialect/engine to track SQL statements for assertion purposes,
+provides base test classes for common test scenarios."""
+
 import sys
 sys.path.insert(0, './lib/')
-import os, unittest, StringIO, re, ConfigParser
+
+import os, unittest, StringIO, re, ConfigParser, optparse
 import sqlalchemy
-from sqlalchemy import sql, engine, pool
-import sqlalchemy.engine.base as base
-import optparse
-from sqlalchemy.schema import BoundMetaData
+from sqlalchemy import sql, engine, pool, BoundMetaData
 from sqlalchemy.orm import clear_mappers
 
 db = None
@@ -134,9 +136,6 @@ firebird=firebird://sysdba:s@localhost/tmp/test.fdb
         return ExecutionContextWrapper(create_context(*args, **kwargs))
     db.dialect.create_execution_context = create_exec_context
     
-    global testdata
-    testdata = TestData(db)
-    
     if options.topological:
         from sqlalchemy.orm import unitofwork
         from sqlalchemy import topological
@@ -161,6 +160,7 @@ firebird=firebird://sysdba:s@localhost/tmp/test.fdb
     
 def unsupported(*dbs):
     """a decorator that marks a test as unsupported by one or more database implementations"""
+    
     def decorate(func):
         name = db.name
         for d in dbs:
@@ -175,6 +175,7 @@ def unsupported(*dbs):
 
 def supported(*dbs):
     """a decorator that marks a test as supported by one or more database implementations"""
+    
     def decorate(func):
         name = db.name
         for d in dbs:
@@ -189,19 +190,27 @@ def supported(*dbs):
 
         
 class PersistTest(unittest.TestCase):
-    """persist base class, provides default setUpAll, tearDownAll and echo functionality"""
+
     def __init__(self, *args, **params):
         unittest.TestCase.__init__(self, *args, **params)
+
     def echo(self, text):
+        """DEPRECATED.  use print <statement>"""
         echo_text(text)
+        
     def install_threadlocal(self):
+        """DEPRECATED."""
         sqlalchemy.mods.threadlocal.install_plugin()
+        
     def uninstall_threadlocal(self):
+        """DEPRECATED."""
         sqlalchemy.mods.threadlocal.uninstall_plugin()
+
     def setUpAll(self):
         pass
     def tearDownAll(self):
         pass
+
     def shortDescription(self):
         """overridden to not return docstrings"""
         return None
@@ -209,15 +218,18 @@ class PersistTest(unittest.TestCase):
 class AssertMixin(PersistTest):
     """given a list-based structure of keys/properties which represent information within an object structure, and
     a list of actual objects, asserts that the list of objects corresponds to the structure."""
+    
     def assert_result(self, result, class_, *objects):
         result = list(result)
         if echo:
             print repr(result)
         self.assert_list(result, class_, objects)
+        
     def assert_list(self, result, class_, list):
         self.assert_(len(result) == len(list), "result list is not the same size as test list, for class " + class_.__name__)
         for i in range(0, len(list)):
             self.assert_row(class_, result[i], list[i])
+            
     def assert_row(self, class_, rowobj, desc):
         self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_))
         for key, value in desc.iteritems():
@@ -228,9 +240,10 @@ class AssertMixin(PersistTest):
                     self.assert_row(value[0], getattr(rowobj, key), value[1])
             else:
                 self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
+                
     def assert_sql(self, db, callable_, list, with_sequences=None):
         global testdata
-        testdata = TestData(db)
+        testdata = TestData()
         if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
             testdata.set_assert_list(self, with_sequences)
         else:
@@ -242,7 +255,7 @@ class AssertMixin(PersistTest):
 
     def assert_sql_count(self, db, callable_, count):
         global testdata
-        testdata = TestData(db)
+        testdata = TestData()
         try:
             callable_()
         finally:
@@ -250,7 +263,7 @@ class AssertMixin(PersistTest):
 
     def capture_sql(self, db, callable_):
         global testdata
-        testdata = TestData(db)
+        testdata = TestData()
         buffer = StringIO.StringIO()
         testdata.buffer = buffer
         try:
@@ -281,9 +294,9 @@ class ORMTest(AssertMixin):
                 t.delete().execute().close()
 
 class TestData(object):
-    def __init__(self, engine):
-        self._engine = engine
-        self.logger = engine.logger
+    """tracks SQL expressions as theyre executed via an instrumented ExecutionContext."""
+    
+    def __init__(self):
         self.set_assert_list(None, None)
         self.sql_count = 0
         self.buffer = None
@@ -293,8 +306,13 @@ class TestData(object):
         self.assert_list = list
         if list is not None:
             self.assert_list.reverse()
-    
+
+testdata = TestData()
+
 class ExecutionContextWrapper(object):
+    """instruments the ExecutionContext created by the Engine so that SQL expressions
+    can be tracked."""
+    
     def __init__(self, ctx):
         self.__dict__['ctx'] = ctx
     def __getattr__(self, key):
index 244a53d0e9b3dc1b944f3728bfc4f4334db65ab3..11eaf4fd04020c334c04e57bb6ab9cfda7519ac2 100644 (file)
@@ -4,6 +4,7 @@ import zblog.tables as tables
 import zblog.user as user
 from zblog.blog import *
 from sqlalchemy import *
+from sqlalchemy.orm import *
 import sqlalchemy.util as util
 
 def zblog_mappers():
index e538cff9d832b7340c42f35f0367fac9330280fc..f0b5f58f251f9c7b619a4e8d0fe2be6dfad80fa5 100644 (file)
@@ -4,6 +4,7 @@ import unittest
 
 db = testbase.db
 from sqlalchemy import *
+from sqlalchemy.orm import *
 
 from zblog import mappers, tables
 from zblog.user import *