self._process_key_switches(deplist, uowcommit)
def _process_key_switches(self, deplist, uowcommit):
- switchers = util.Set(s for s in deplist if self._pks_changed(uowcommit, s))
+ switchers = util.Set([s for s in deplist if self._pks_changed(uowcommit, s)])
if switchers:
# yes, we're doing a linear search right now through the UOW. only
# takes effect when primary key values have actually changed.
if tuples:
rows = util.OrderedSet()
for row in fetch:
- rows.add(tuple(proc(context, row) for proc in process))
+ rows.add(tuple([proc(context, row) for proc in process]))
else:
rows = util.UniqueAppender([])
for row in fetch:
import unittest
from sqlalchemy import util, sql, exceptions
from testlib import *
-
+from testlib import sorted
class OrderedDictTest(PersistTest):
def test_odict(self):
sess.delete(u)
sess.close()
+
def create_backref_test(autoflush, saveuser):
def test_backref(self):
mapper(User, users, properties={
sess.flush()
self.assert_(list(u.addresses) == [])
- test_backref.__name__ = "test%s%s" % (
- (autoflush and "_autoflush" or ""),
- (saveuser and "_saveuser" or "_savead"),
- )
+ test_backref = _function_named(
+ test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""),
+ (saveuser and "_saveuser" or "_savead")))
setattr(FlushTest, test_backref.__name__, test_backref)
for autoflush in (False, True):
class DontDereferenceTest(ORMTest):
def define_tables(self, metadata):
global users_table, addresses_table
-
+
users_table = Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(40)),
session.save(user)
session.flush()
session.clear()
-
+
def query1():
session = create_session(metadata.bind)
user = session.query(User).first()
self.assertEquals(query1(), [Address(email_address='joe@joesdomain.example')] )
self.assertEquals(query2(), [Address(email_address='joe@joesdomain.example')] )
self.assertEquals(query3(), [Address(email_address='joe@joesdomain.example')] )
-
-
+
+
if __name__ == '__main__':
testenv.main()
C(cdata='c2', bdata='c2', adata='c2'),
] == sess.query(C).all()
- test_roundtrip.__name__ = 'test_%s' % fetchtype
+ test_roundtrip = _function_named(
+ test_roundtrip, 'test_%s' % fetchtype)
return test_roundtrip
test_union = make_test('union')
self.assert_sql_count(testing.db, go, 3)
- test_get.__name__ = name
+ test_get = _function_named(test_get, name)
return test_get
test_get_polymorphic = create_test(True, 'test_get_polymorphic')
print [page, page2, page3]
assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages)
- test_roundtrip.__name__ = "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")
+ test_roundtrip = _function_named(
+ test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions"))
setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
for (use_union, use_join) in [(True, False), (False, True), (False, False)]:
session.delete(c)
session.flush()
- test_roundtrip.__name__ = "test_%s%s%s%s%s" % (
- (lazy_relation and "lazy" or "eager"),
- (include_base and "_inclbase" or ""),
- (redefine_colprop and "_redefcol" or ""),
- (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
- (use_outer_joins and '_outerjoins' or '')
- )
+ test_roundtrip = _function_named(
+ test_roundtrip, "test_%s%s%s%s%s" % (
+ (lazy_relation and "lazy" or "eager"),
+ (include_base and "_inclbase" or ""),
+ (redefine_colprop and "_redefcol" or ""),
+ (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
+ (use_outer_joins and '_outerjoins' or '')))
setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
for include_base in [True, False]:
assert p.data.data == 'ps data'
assert m.data.data == 'ms data'
- do_test.__name__ = 'test_relationonbaseclass_%s_%s' % (jointype, data and "nodata" or "data")
+ do_test = _function_named(
+ do_test, 'test_relationonbaseclass_%s_%s' % (
+ jointype, data and "nodata" or "data"))
return do_test
for jointype in ["join1", "join2", "join3", "join4"]:
pass
class MyCollection(object):
- def __init__(self): self.data = []
+ def __init__(self):
+ self.data = []
@collection.appender
- def append(self, value): self.data.append(value)
+ def append(self, value):
+ self.data.append(value)
@collection.remover
- def remove(self, value): self.data.remove(value)
+ def remove(self, value):
+ self.data.remove(value)
@collection.iterator
- def __iter__(self): return iter(self.data)
+ def __iter__(self):
+ return iter(self.data)
mapper(Parent, sometable, properties={
'children':relation(Child, collection_class=MyCollection)
from testlib.testing import PersistTest, AssertMixin, ORMTest, SQLCompileTest
import testlib.profiling as profiling
import testlib.engines as engines
+from testlib.compat import set, sorted, _function_named
__all__ = ('testing',
'Table', 'Column',
'rowset',
'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest',
- 'profiling', 'engines')
+ 'profiling', 'engines',
+ 'set', 'sorted', '_function_named')
--- /dev/null
+import new
+
+__all__ = 'set', 'sorted', '_function_named'
+
+try:
+ set = set
+except NameError:
+ from sets import Set as set
+
+try:
+ sorted = sorted
+except NameError:
+ def sorted(iterable):
+ return list(iterable).sort()
+
+def _function_named(fn, newname):
+ try:
+ fn.__name__ = newname
+ except:
+ fn = new.function(fn.func_code, fn.func_globals, newname,
+ fn.func_defaults, fn.func_closure)
+ return fn
import sys, weakref
from testlib import config
+from testlib.compat import *
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
-
+
def checkout(self, dbapi_con, con_record, con_proxy):
self.proxy_refs[con_proxy] = True
-
+
def _apply_all(self, methods):
for rec in self.proxy_refs:
if rec is not None and rec.is_valid:
def close_all(self):
self._apply_all(('rollback', 'close'))
-
+
def assert_all_closed(self):
for rec in self.proxy_refs:
if rec.is_valid:
assert False
-
+
testing_reaper = ConnectionKiller()
def assert_conns_closed(fn):
fn(*args, **kw)
finally:
testing_reaper.assert_all_closed()
- decorated.__name__ = fn.__name__
- return decorated
-
+ return _function_named(decorated, fn.__name__)
+
def rollback_open_connections(fn):
"""Decorator that rolls back all open connections after fn execution."""
fn(*args, **kw)
finally:
testing_reaper.rollback_all()
- decorated.__name__ = fn.__name__
- return decorated
+ return _function_named(decorated, fn.__name__)
def close_open_connections(fn):
"""Decorator that closes all connections after fn execution."""
fn(*args, **kw)
finally:
testing_reaper.close_all()
- decorated.__name__ = fn.__name__
- return decorated
+ return _function_named(decorated, fn.__name__)
class ReconnectFixture(object):
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
-
+
def __getattr__(self, key):
return getattr(self.dbapi, key)
for c in list(self.connections):
c.close()
self.connections = []
-
+
def reconnecting_engine(url=None, options=None):
url = url or config.db_url
dbapi = config.db.dialect.dbapi
engine = testing_engine(url, {'module':ReconnectFixture(dbapi)})
engine.test_shutdown = engine.dialect.dbapi.shutdown
return engine
-
+
def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
-
+
from sqlalchemy import create_engine
from testlib.testing import ExecutionContextWrapper
url = str(url)
return testing_engine(url, options)
-
-
import os, sys
from testlib.config import parser, post_configure
+from testlib.compat import *
import testlib.config
__all__ = 'profiled', 'function_call_count'
class BaseObject(object):
def __repr__(self):
- return "%s(%s)" % (self.__class__.__name__, ",".join("%s=%s" % (k, repr(v)) for k, v in self.__dict__.iteritems() if k[0] != '_'))
+ return "%s(%s)" % (self.__class__.__name__,
+ ",".join(["%s=%s" % (k, repr(v))
+ for k, v in self.__dict__.iteritems()
+ if k[0] != '_']))
class User(BaseObject):
def __init__(self):
# monkeypatches unittest.TestLoader.suiteClass at import time
-import itertools, unittest, re, sys, os, operator, warnings
+import itertools, os, operator, re, sys, unittest, warnings
from cStringIO import StringIO
import testlib.config as config
+from testlib.compat import *
+
sql, MetaData, clear_mappers, Session, util = None, None, None, None, None
sa_exceptions = None
# sugar ('testing.db'); set here by config() at runtime
db = None
+
def fails_on(*dbs):
"""Mark a test as expected to fail on one or more database implementations.
raise AssertionError(
"Unexpected success for '%s' on DB implementation '%s'" %
(fn_name, config.db.name))
- try:
- maybe.__name__ = fn_name
- except:
- pass
- return maybe
+ return _function_named(maybe, fn_name)
return decorate
def fails_on_everything_except(*dbs):
raise AssertionError(
"Unexpected success for '%s' on DB implementation '%s'" %
(fn_name, config.db.name))
- try:
- maybe.__name__ = fn_name
- except:
- pass
- return maybe
+ return _function_named(maybe, fn_name)
return decorate
def unsupported(*dbs):
return True
else:
return fn(*args, **kw)
- try:
- maybe.__name__ = fn_name
- except:
- pass
- return maybe
+ return _function_named(maybe, fn_name)
return decorate
def exclude(db, op, spec):
"""Mark a test as unsupported by specific database server versions.
Stackable, both with other excludes and other decorators. Examples::
+
# Not supported by mydb versions less than 1, 0
@exclude('mydb', '<', (1,0))
# Other operators work too
return True
else:
return fn(*args, **kw)
- try:
- maybe.__name__ = fn_name
- except:
- pass
- return maybe
+ return _function_named(maybe, fn_name)
return decorate
def _is_excluded(db, op, spec):
return fn(*args, **kw)
finally:
resetwarnings()
- try:
- safe.__name__ = fn.__name__
- except:
- pass
- return safe
+ return _function_named(safe, fn.__name__)
return decorate
def uses_deprecated(*messages):
return fn(*args, **kw)
finally:
resetwarnings()
- try:
- safe.__name__ = fn.__name__
- except:
- pass
- return safe
+ return _function_named(safe, fn.__name__)
return decorate
def resetwarnings():
__all__ = ['Blog', 'Post', 'Topic', 'TopicAssociation', 'Comment']
import datetime
+from testlib.compat import *
class Blog(object):
def __init__(self, owner=None):