]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Restored 2.3 compat. in lib/sqlalchemy
authorJason Kirtland <jek@discorporate.us>
Sat, 19 Jan 2008 23:37:11 +0000 (23:37 +0000)
committerJason Kirtland <jek@discorporate.us>
Sat, 19 Jan 2008 23:37:11 +0000 (23:37 +0000)
- Part one of test suite fixes to run on 2.3
  Lots of failures still around sets; sets.Set differs from __builtin__.set
  particularly in the binops. We depend on set extensively now and may need to
  provide a corrected sets.Set subclass on 2.3.

17 files changed:
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/query.py
test/base/utils.py
test/orm/dynamic.py
test/orm/inheritance/abc_polymorphic.py
test/orm/inheritance/basic.py
test/orm/inheritance/magazine.py
test/orm/inheritance/polymorph.py
test/orm/inheritance/polymorph2.py
test/orm/relationships.py
test/testlib/__init__.py
test/testlib/compat.py [new file with mode: 0644]
test/testlib/engines.py
test/testlib/profiling.py
test/testlib/tables.py
test/testlib/testing.py
test/zblog/blog.py

index c28598542870dea38b73b3f9a02bd055d4b44100..f675aff5009bda2989c34e37b843e5b0ad55df33 100644 (file)
@@ -282,7 +282,7 @@ class DetectKeySwitch(DependencyProcessor):
             self._process_key_switches(deplist, uowcommit)
     
     def _process_key_switches(self, deplist, uowcommit):    
-        switchers = util.Set(s for s in deplist if self._pks_changed(uowcommit, s))
+        switchers = util.Set([s for s in deplist if self._pks_changed(uowcommit, s)])
         if switchers:
             # yes, we're doing a linear search right now through the UOW.  only 
             # takes effect when primary key values have actually changed.
index ab31b328bfe59d0fc40c68c7af25f4d437539950..201d0e2e308264b74960461ae60279e6002110f1 100644 (file)
@@ -880,7 +880,7 @@ class Query(object):
             if tuples:
                 rows = util.OrderedSet()
                 for row in fetch:
-                    rows.add(tuple(proc(context, row) for proc in process))
+                    rows.add(tuple([proc(context, row) for proc in process]))
             else:
                 rows = util.UniqueAppender([])
                 for row in fetch:
index 5ebc33921553a9a3d328d6cc40d6722671cce933..61e2b95a8a692c701da72ffef57355a142b56343 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import unittest
 from sqlalchemy import util, sql, exceptions
 from testlib import *
-
+from testlib import sorted
 
 class OrderedDictTest(PersistTest):
     def test_odict(self):
index 199eb474ffa789621aee557032e01502b5fb2dee..3382f0205992fca63f2761ee14aeb247715f57ea 100644 (file)
@@ -170,6 +170,7 @@ class FlushTest(FixtureTest):
         sess.delete(u)
         sess.close()
 
+
 def create_backref_test(autoflush, saveuser):
     def test_backref(self):
         mapper(User, users, properties={
@@ -203,10 +204,9 @@ def create_backref_test(autoflush, saveuser):
             sess.flush()
         self.assert_(list(u.addresses) == [])
 
-    test_backref.__name__ = "test%s%s" % (
-        (autoflush and "_autoflush" or ""),
-        (saveuser and "_saveuser" or "_savead"),
-    )
+    test_backref = _function_named(
+        test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""),
+                                    (saveuser and "_saveuser" or "_savead")))
     setattr(FlushTest, test_backref.__name__, test_backref)
 
 for autoflush in (False, True):
@@ -216,7 +216,7 @@ for autoflush in (False, True):
 class DontDereferenceTest(ORMTest):
     def define_tables(self, metadata):
         global users_table, addresses_table
-        
+
         users_table = Table('users', metadata,
                            Column('id', Integer, primary_key=True),
                            Column('name', String(40)),
@@ -245,7 +245,7 @@ class DontDereferenceTest(ORMTest):
         session.save(user)
         session.flush()
         session.clear()
-        
+
         def query1():
             session = create_session(metadata.bind)
             user = session.query(User).first()
@@ -263,7 +263,7 @@ class DontDereferenceTest(ORMTest):
         self.assertEquals(query1(), [Address(email_address='joe@joesdomain.example')]  )
         self.assertEquals(query2(), [Address(email_address='joe@joesdomain.example')]  )
         self.assertEquals(query3(), [Address(email_address='joe@joesdomain.example')]  )
-        
-        
+
+
 if __name__ == '__main__':
     testenv.main()
index 79cc91d2c196df9d45bf9aa9b8731f18e71099ae..076c7b76b8e787143f08089a0ee6c86be1fc1e39 100644 (file)
@@ -77,7 +77,8 @@ class ABCTest(ORMTest):
                 C(cdata='c2', bdata='c2', adata='c2'),
             ] == sess.query(C).all()
 
-        test_roundtrip.__name__ = 'test_%s' % fetchtype
+        test_roundtrip = _function_named(
+            test_roundtrip, 'test_%s' % fetchtype)
         return test_roundtrip
 
     test_union = make_test('union')
index 39b9fb91670c236c02acd2c1ecdbfe080ba1e527..3f3bf4bdb760c2cee4ebb3bf98abd7a7809fefd6 100644 (file)
@@ -215,7 +215,7 @@ class GetTest(ORMTest):
 
                 self.assert_sql_count(testing.db, go, 3)
 
-        test_get.__name__ = name
+        test_get = _function_named(test_get, name)
         return test_get
 
     test_get_polymorphic = create_test(True, 'test_get_polymorphic')
index b5c5096f5bb5c11dce427f19b9ec9836476a2417..621f9639f41a6c41f7bc35c025b55bd2f9652011 100644 (file)
@@ -207,7 +207,8 @@ def generate_round_trip_test(use_unions=False, use_joins=False):
         print [page, page2, page3]
         assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages)
 
-    test_roundtrip.__name__ = "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")
+    test_roundtrip = _function_named(
+        test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions"))
     setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
 
 for (use_union, use_join) in [(True, False), (False, True), (False, False)]:
index 5f0e74fa351fd0d875bf2b0dcaefb1f3cd3e880f..faee633601b3784676fb7ed8418d52935a275182 100644 (file)
@@ -342,13 +342,13 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         session.delete(c)
         session.flush()
 
-    test_roundtrip.__name__ = "test_%s%s%s%s%s" % (
-        (lazy_relation and "lazy" or "eager"),
-        (include_base and "_inclbase" or ""),
-        (redefine_colprop and "_redefcol" or ""),
-        (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
-        (use_outer_joins and '_outerjoins' or '')
-    )
+    test_roundtrip = _function_named(
+        test_roundtrip, "test_%s%s%s%s%s" % (
+          (lazy_relation and "lazy" or "eager"),
+          (include_base and "_inclbase" or ""),
+          (redefine_colprop and "_redefcol" or ""),
+          (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
+          (use_outer_joins and '_outerjoins' or '')))
     setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
 
 for include_base in [True, False]:
index 96a8ddbf65b41fc3120f12779a81783216669528..0f0fff6729a4dbc0c38a63be7f5ccf9dbe94d0ef 100644 (file)
@@ -289,7 +289,9 @@ def generate_test(jointype="join1", usedata=False):
             assert p.data.data == 'ps data'
             assert m.data.data == 'ms data'
 
-    do_test.__name__ = 'test_relationonbaseclass_%s_%s' % (jointype, data and "nodata" or "data")
+    do_test = _function_named(
+        do_test, 'test_relationonbaseclass_%s_%s' % (
+        jointype, data and "nodata" or "data"))
     return do_test
 
 for jointype in ["join1", "join2", "join3", "join4"]:
index e23f107ee2c66e79d09d988530709c4d37b46a26..40f70bf6271f1cdc7daa84e6c82997bd1d7e4309 100644 (file)
@@ -999,13 +999,17 @@ class CustomCollectionsTest(ORMTest):
             pass
 
         class MyCollection(object):
-            def __init__(self): self.data = []
+            def __init__(self):
+                self.data = []
             @collection.appender
-            def append(self, value): self.data.append(value)
+            def append(self, value):
+                self.data.append(value)
             @collection.remover
-            def remove(self, value): self.data.remove(value)
+            def remove(self, value):
+                self.data.remove(value)
             @collection.iterator
-            def __iter__(self): return iter(self.data)
+            def __iter__(self):
+                return iter(self.data)
 
         mapper(Parent, sometable, properties={
             'children':relation(Child, collection_class=MyCollection)
index d7daaddf8ba156733466deacab308b92f66595f4..46852191a26336df73892cd2c641d0a9873ee1f6 100644 (file)
@@ -11,6 +11,7 @@ from testlib.testing import rowset
 from testlib.testing import PersistTest, AssertMixin, ORMTest, SQLCompileTest
 import testlib.profiling as profiling
 import testlib.engines as engines
+from testlib.compat import set, sorted, _function_named
 
 
 __all__ = ('testing',
@@ -18,4 +19,5 @@ __all__ = ('testing',
            'Table', 'Column',
            'rowset',
            'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest',
-           'profiling', 'engines')
+           'profiling', 'engines',
+           'set', 'sorted', '_function_named')
diff --git a/test/testlib/compat.py b/test/testlib/compat.py
new file mode 100644 (file)
index 0000000..590bf50
--- /dev/null
@@ -0,0 +1,22 @@
+import new
+
+__all__ = 'set', 'sorted', '_function_named'
+
+try:
+    set = set
+except NameError:
+    from sets import Set as set
+
+try:
+    sorted = sorted
+except NameError:
+    def sorted(iterable):
+        return list(iterable).sort()
+
+def _function_named(fn, newname):
+    try:
+        fn.__name__ = newname
+    except:
+        fn = new.function(fn.func_code, fn.func_globals, newname,
+                          fn.func_defaults, fn.func_closure)
+    return fn
index b576a15364a08ef407665bbccb310e67f061bcd6..8cb321597245b1b1f7527b6f61e6dfacf57be9eb 100644 (file)
@@ -1,14 +1,15 @@
 import sys, weakref
 from testlib import config
+from testlib.compat import *
 
 
 class ConnectionKiller(object):
     def __init__(self):
         self.proxy_refs = weakref.WeakKeyDictionary()
-        
+
     def checkout(self, dbapi_con, con_record, con_proxy):
         self.proxy_refs[con_proxy] = True
-        
+
     def _apply_all(self, methods):
         for rec in self.proxy_refs:
             if rec is not None and rec.is_valid:
@@ -29,12 +30,12 @@ class ConnectionKiller(object):
 
     def close_all(self):
         self._apply_all(('rollback', 'close'))
-        
+
     def assert_all_closed(self):
         for rec in self.proxy_refs:
             if rec.is_valid:
                 assert False
-        
+
 testing_reaper = ConnectionKiller()
 
 def assert_conns_closed(fn):
@@ -43,9 +44,8 @@ def assert_conns_closed(fn):
             fn(*args, **kw)
         finally:
             testing_reaper.assert_all_closed()
-    decorated.__name__ = fn.__name__
-    return decorated
-    
+    return _function_named(decorated, fn.__name__)
+
 def rollback_open_connections(fn):
     """Decorator that rolls back all open connections after fn execution."""
 
@@ -54,8 +54,7 @@ def rollback_open_connections(fn):
             fn(*args, **kw)
         finally:
             testing_reaper.rollback_all()
-    decorated.__name__ = fn.__name__
-    return decorated
+    return _function_named(decorated, fn.__name__)
 
 def close_open_connections(fn):
     """Decorator that closes all connections after fn execution."""
@@ -65,14 +64,13 @@ def close_open_connections(fn):
             fn(*args, **kw)
         finally:
             testing_reaper.close_all()
-    decorated.__name__ = fn.__name__
-    return decorated
+    return _function_named(decorated, fn.__name__)
 
 class ReconnectFixture(object):
     def __init__(self, dbapi):
         self.dbapi = dbapi
         self.connections = []
-    
+
     def __getattr__(self, key):
         return getattr(self.dbapi, key)
 
@@ -85,17 +83,17 @@ class ReconnectFixture(object):
         for c in list(self.connections):
             c.close()
         self.connections = []
-        
+
 def reconnecting_engine(url=None, options=None):
     url = url or config.db_url
     dbapi = config.db.dialect.dbapi
     engine = testing_engine(url, {'module':ReconnectFixture(dbapi)})
     engine.test_shutdown = engine.dialect.dbapi.shutdown
     return engine
-    
+
 def testing_engine(url=None, options=None):
     """Produce an engine configured by --options with optional overrides."""
-    
+
     from sqlalchemy import create_engine
     from testlib.testing import ExecutionContextWrapper
 
@@ -133,5 +131,3 @@ def utf8_engine(url=None, options=None):
             url = str(url)
 
     return testing_engine(url, options)
-
-
index ac7ca84d7e20552a5198727eca7cb359d7522fe3..8867d016923ce8cc30d2d8174a82ac4a5d80dba2 100644 (file)
@@ -2,6 +2,7 @@
 
 import os, sys
 from testlib.config import parser, post_configure
+from testlib.compat import *
 import testlib.config
 
 __all__ = 'profiled', 'function_call_count'
index 4ec92cc8bb1d15cf9bd05afdaebe5828288a1ad8..33b1b20db9aaddbf087407b6892dceda5b7cddfc 100644 (file)
@@ -135,7 +135,10 @@ def data():
 
 class BaseObject(object):
     def __repr__(self):
-        return "%s(%s)" % (self.__class__.__name__, ",".join("%s=%s" % (k, repr(v)) for k, v in self.__dict__.iteritems() if k[0] != '_'))
+        return "%s(%s)" % (self.__class__.__name__,
+                           ",".join(["%s=%s" % (k, repr(v))
+                                     for k, v in self.__dict__.iteritems()
+                                     if k[0] != '_']))
 
 class User(BaseObject):
     def __init__(self):
index 1b5a55f047a25c6529319113a275427db9ae7028..b05795efdf390c47d3a0eff1c58438c0adf15a00 100644 (file)
@@ -2,9 +2,11 @@
 
 # monkeypatches unittest.TestLoader.suiteClass at import time
 
-import itertools, unittest, re, sys, os, operator, warnings
+import itertools, os, operator, re, sys, unittest, warnings
 from cStringIO import StringIO
 import testlib.config as config
+from testlib.compat import *
+
 sql, MetaData, clear_mappers, Session, util = None, None, None, None, None
 sa_exceptions = None
 
@@ -23,6 +25,7 @@ _ops = { '<': operator.lt,
 # sugar ('testing.db'); set here by config() at runtime
 db = None
 
+
 def fails_on(*dbs):
     """Mark a test as expected to fail on one or more database implementations.
 
@@ -49,11 +52,7 @@ def fails_on(*dbs):
                     raise AssertionError(
                         "Unexpected success for '%s' on DB implementation '%s'" %
                         (fn_name, config.db.name))
-        try:
-            maybe.__name__ = fn_name
-        except:
-            pass
-        return maybe
+        return _function_named(maybe, fn_name)
     return decorate
 
 def fails_on_everything_except(*dbs):
@@ -80,11 +79,7 @@ def fails_on_everything_except(*dbs):
                     raise AssertionError(
                         "Unexpected success for '%s' on DB implementation '%s'" %
                         (fn_name, config.db.name))
-        try:
-            maybe.__name__ = fn_name
-        except:
-            pass
-        return maybe
+        return _function_named(maybe, fn_name)
     return decorate
 
 def unsupported(*dbs):
@@ -103,17 +98,14 @@ def unsupported(*dbs):
                 return True
             else:
                 return fn(*args, **kw)
-        try:
-            maybe.__name__ = fn_name
-        except:
-            pass
-        return maybe
+        return _function_named(maybe, fn_name)
     return decorate
 
 def exclude(db, op, spec):
     """Mark a test as unsupported by specific database server versions.
 
     Stackable, both with other excludes and other decorators. Examples::
+
       # Not supported by mydb versions less than 1, 0
       @exclude('mydb', '<', (1,0))
       # Other operators work too
@@ -130,11 +122,7 @@ def exclude(db, op, spec):
                 return True
             else:
                 return fn(*args, **kw)
-        try:
-            maybe.__name__ = fn_name
-        except:
-            pass
-        return maybe
+        return _function_named(maybe, fn_name)
     return decorate
 
 def _is_excluded(db, op, spec):
@@ -204,11 +192,7 @@ def emits_warning(*messages):
                 return fn(*args, **kw)
             finally:
                 resetwarnings()
-        try:
-            safe.__name__ = fn.__name__
-        except:
-            pass
-        return safe
+        return _function_named(safe, fn.__name__)
     return decorate
 
 def uses_deprecated(*messages):
@@ -247,11 +231,7 @@ def uses_deprecated(*messages):
                 return fn(*args, **kw)
             finally:
                 resetwarnings()
-        try:
-            safe.__name__ = fn.__name__
-        except:
-            pass
-        return safe
+        return _function_named(safe, fn.__name__)
     return decorate
 
 def resetwarnings():
index 04dd33ac51557269c31f5ff6a6d67738323cdbc1..9e48a202f032e7380db8700244a9da9f7a29c8bb 100644 (file)
@@ -1,6 +1,7 @@
 __all__ = ['Blog', 'Post', 'Topic', 'TopicAssociation', 'Comment']
 
 import datetime
+from testlib.compat import *
 
 class Blog(object):
     def __init__(self, owner=None):