Unit tests are run using nose. Nose is available at::
- http://pypi.python.org/pypi/nose/
+ https://pypi.python.org/pypi/nose/
SQLAlchemy implements a nose plugin that must be present when tests are run.
This plugin is invoked when the test runner script provided with
SQLAlchemy is used.
+The test suite as of version 0.8.2 also requires the mock library. While
+mock is part of the Python standard library as of 3.3, previous versions
+will need to have it installed, and is available at::
+
+ https://pypi.python.org/pypi/mock
+
**NOTE:** - the nose plugin is no longer installed by setuptools as of
version 0.7 ! Use "python setup.py test" or "./sqla_nose.py".
.. changelog::
:version: 0.8.2
+ .. change::
+ :tags: requirements
+
+ The Python `mock <https://pypi.python.org/pypi/mock>`_ library
+ is now required in order to run the unit test suite. While part
+ of the standard library as of Python 3.3, previous Python installations
+ will need to install this in order to run unit tests or to
+ use the ``sqlalchemy.testing`` package for external dialects.
+
.. change::
:tags: bug, orm
:tickets: 2750
crashes = skip
from .config import db, requirements as requires
+
+from . import mock
--- /dev/null
+"""Import stub for mock library.
+"""
+from __future__ import absolute_import
+from ..util import py33
+
+if py33:
+ from unittest.mock import MagicMock, Mock, call
+else:
+ try:
+ from mock import MagicMock, Mock, call
+ except ImportError:
+ raise ImportError(
+ "SQLAlchemy's test suite requires the "
+ "'mock' library as of 0.8.2.")
+
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .compat import callable, cmp, reduce, \
- threading, py3k, py2k, py3k_warning, jython, pypy, cpython, win32, set_types, \
+ threading, py3k, py33, py2k, py3k_warning, jython, pypy, cpython, win32, \
+ set_types, \
pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \
raise_from_cause, u, b, ue, string_types, text_type, int_types
except ImportError:
import dummy_threading as threading
+py33 = sys.version_info >= (3, 3)
py32 = sys.version_info >= (3, 2)
py3k_warning = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
py3k = sys.version_info >= (3, 0)
license="MIT License",
cmdclass=cmdclass,
- tests_require=['nose >= 0.11'],
+ tests_require=['nose >= 0.11', 'mock'],
test_suite="sqla_nose",
long_description=readme,
classifiers=[
from sqlalchemy.testing import fixtures, AssertsExecutionResults, profiling
from sqlalchemy import testing
from sqlalchemy.testing import eq_
+from sqlalchemy.engine.result import RowProxy
+import sys
+
NUM_FIELDS = 10
NUM_RECORDS = 1000
__requires__ = 'cpython',
def _rowproxy_fixture(self, keys, processors, row):
- from sqlalchemy.engine.result import RowProxy
class MockMeta(object):
def __init__(self):
pass
return RowProxy(metadata, row, processors, keymap)
def _test_getitem_value_refcounts(self, seq_factory):
- import sys
col1, col2 = object(), object()
def proc1(value):
return value
from sqlalchemy import event, exc
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.util import gc_collect
+from sqlalchemy.testing.mock import Mock, call
+
class EventsTest(fixtures.TestBase):
"""Test class- and instance-level event registration."""
def test_lis_subcalss_lis(self):
@event.listens_for(self.TargetOne, "event_one")
def handler1(x, y):
- print 'handler1'
+ pass
class SubTarget(self.TargetOne):
pass
def test_lis_multisub_lis(self):
@event.listens_for(self.TargetOne, "event_one")
def handler1(x, y):
- print 'handler1'
+ pass
class SubTarget(self.TargetOne):
pass
event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
def test_listen_override(self):
- result = []
- def listen_one(x):
- result.append(x)
-
- def listen_two(x, y):
- result.append((x, y))
+ listen_one = Mock()
+ listen_two = Mock()
event.listen(self.Target, "event_one", listen_one, add=True)
event.listen(self.Target, "event_one", listen_two)
t1.dispatch.event_one(5, 7)
t1.dispatch.event_one(10, 5)
- eq_(result,
- [
- 12, (5, 7), 15, (10, 5)
- ]
+ eq_(
+ listen_one.mock_calls,
+ [call(12), call(15)]
+ )
+ eq_(
+ listen_two.mock_calls,
+ [call(5, 7), call(10, 5)]
)
class PropagateTest(fixtures.TestBase):
def test_propagate(self):
- result = []
- def listen_one(target, arg):
- result.append((target, arg))
-
- def listen_two(target, arg):
- result.append((target, arg))
+ listen_one = Mock()
+ listen_two = Mock()
t1 = self.Target()
t2.dispatch.event_one(t2, 1)
t2.dispatch.event_two(t2, 2)
- eq_(result, [(t2, 1)])
+
+ eq_(
+ listen_one.mock_calls,
+ [call(t2, 1)]
+ )
+ eq_(
+ listen_two.mock_calls,
+ []
+ )
class JoinTest(fixtures.TestBase):
def setUp(self):
if 'dispatch' in cls.__dict__:
event._remove_dispatcher(cls.__dict__['dispatch'].events)
- def _listener(self):
- canary = []
- def listen(target, arg):
- canary.append((target, arg))
- return listen, canary
-
def test_neither(self):
element = self.TargetFactory().create()
element.run_event(1)
element.run_event(3)
def test_parent_class_only(self):
- _listener, canary = self._listener()
+ l1 = Mock()
- event.listen(self.TargetFactory, "event_one", _listener)
+ event.listen(self.TargetFactory, "event_one", l1)
element = self.TargetFactory().create()
element.run_event(1)
element.run_event(2)
element.run_event(3)
eq_(
- canary,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_class_child_class(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetFactory, "event_one", l1)
event.listen(self.TargetElement, "event_one", l2)
element.run_event(2)
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_class_child_instance_apply_after(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetFactory, "event_one", l1)
element = self.TargetFactory().create()
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 2), call(element, 3)]
)
def test_parent_class_child_instance_apply_before(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetFactory, "event_one", l1)
element = self.TargetFactory().create()
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_instance_child_class_apply_before(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetElement, "event_one", l2)
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
+
def test_parent_instance_child_class_apply_after(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
event.listen(self.TargetElement, "event_one", l2)
# this can be changed to be "live" at the cost
# of performance.
eq_(
- c1,
- []
- #(element, 2), (element, 3)]
+ l1.mock_calls, []
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_instance_child_instance_apply_before(self):
- l1, c1 = self._listener()
- l2, c2 = self._listener()
+ l1 = Mock()
+ l2 = Mock()
factory = self.TargetFactory()
event.listen(factory, "event_one", l1)
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
eq_(
- c2,
- [(element, 1), (element, 2), (element, 3)]
+ l2.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
def test_parent_events_child_no_events(self):
- l1, c1 = self._listener()
+ l1 = Mock()
factory = self.TargetFactory()
event.listen(self.TargetElement, "event_one", l1)
element.run_event(3)
eq_(
- c1,
- [(element, 1), (element, 2), (element, 3)]
+ l1.mock_calls,
+ [call(element, 1), call(element, 2), call(element, 3)]
)
from sqlalchemy.dialects.postgresql import base as postgresql
import logging
import logging.handlers
+from sqlalchemy.testing.mock import Mock
class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
'The JDBC driver handles the version parsing')
def test_version_parsing(self):
-
- class MockConn(object):
-
- def __init__(self, res):
- self.res = res
-
- def execute(self, str):
- return self
-
- def scalar(self):
- return self.res
-
+ def mock_conn(res):
+ return Mock(
+ execute=Mock(
+ return_value=Mock(scalar=Mock(return_value=res))
+ )
+ )
for string, version in \
[('PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by '
('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, '
'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), '
'64-bit', (9, 1, 2))]:
- eq_(testing.db.dialect._get_server_version_info(MockConn(string)),
+ eq_(testing.db.dialect._get_server_version_info(mock_conn(string)),
version)
@testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature')
from sqlalchemy.testing import eq_
from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
-
-# TODO: we should probably build mock bases for
-# these to share with test_reconnect, test_parseconnect
-class MockDBAPI(object):
- paramstyle = 'qmark'
- def __init__(self):
- self.log = []
- def connect(self, *args, **kwargs):
- return MockConnection(self)
-
-class MockConnection(object):
- def __init__(self, parent):
- self.parent = parent
- def cursor(self):
- return MockCursor(self)
- def close(self):
- pass
- def rollback(self):
- pass
- def commit(self):
- pass
-
-class MockCursor(object):
- description = None
- rowcount = None
- def __init__(self, parent):
- self.parent = parent
- def execute(self, *args, **kwargs):
- if kwargs.get('direct', False):
- self.executedirect()
- else:
- self.parent.parent.log.append('execute')
- def executedirect(self, *args, **kwargs):
- self.parent.parent.log.append('executedirect')
- def close(self):
- pass
+from sqlalchemy.testing.mock import Mock
+
+def mock_dbapi():
+ return Mock(paramstyle='qmark',
+ connect=Mock(
+ return_value=Mock(
+ cursor=Mock(
+ return_value=Mock(
+ description=None,
+ rowcount=None)
+ )
+ )
+ )
+ )
class MxODBCTest(fixtures.TestBase):
def test_native_odbc_execute(self):
t1 = Table('t1', MetaData(), Column('c1', Integer))
- dbapi = MockDBAPI()
+ dbapi = mock_dbapi()
+
engine = engines.testing_engine('mssql+mxodbc://localhost',
options={'module': dbapi, '_initialize': False})
conn = engine.connect()
# crud: uses execute
-
conn.execute(t1.insert().values(c1='foo'))
conn.execute(t1.delete().where(t1.c.c1 == 'foo'))
- conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar'
- ))
+ conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar'))
# select: uses executedirect
-
conn.execute(t1.select())
# manual flagging
-
conn.execution_options(native_odbc_execute=True).\
execute(t1.select())
conn.execution_options(native_odbc_execute=False).\
- execute(t1.insert().values(c1='foo'
- ))
- eq_(dbapi.log, [
- 'executedirect',
- 'executedirect',
- 'executedirect',
- 'executedirect',
- 'execute',
- 'executedirect',
- ])
+ execute(t1.insert().values(c1='foo'))
+
+ eq_(
+ [c[2] for c in
+ dbapi.connect.return_value.cursor.return_value.execute.mock_calls],
+ [{'direct': True}, {'direct': True}, {'direct': True},
+ {'direct': True}, {'direct': False}, {'direct': True}]
+ )
from sqlalchemy.engine import default
from sqlalchemy import MetaData, Table, Column, Integer, Sequence
from sqlalchemy import schema
+from sqlalchemy.testing.mock import Mock
class EmitDDLTest(fixtures.TestBase):
def _mock_connection(self, item_exists):
- _canary = []
+ def has_item(connection, name, schema):
+ return item_exists(name)
- class MockDialect(default.DefaultDialect):
- supports_sequences = True
-
- def has_table(self, connection, name, schema):
- return item_exists(name)
-
- def has_sequence(self, connection, name, schema):
- return item_exists(name)
-
- class MockConnection(object):
- dialect = MockDialect()
- canary = _canary
-
- def execute(self, item):
- _canary.append(item)
-
- return MockConnection()
+ return Mock(dialect=Mock(
+ supports_sequences=True,
+ has_table=Mock(side_effect=has_item),
+ has_sequence=Mock(side_effect=has_item)
+ )
+ )
def _mock_create_fixture(self, checkfirst, tables,
item_exists=lambda item: False):
def _assert_ddl(self, ddl_cls, elements, generator, argument):
generator.traverse_single(argument)
- for c in generator.connection.canary:
+ for call_ in generator.connection.execute.mock_calls:
+ c = call_[1][0]
assert isinstance(c, ddl_cls)
assert c.element in elements, "element %r was not expected"\
% c.element
from sqlalchemy.engine import result as _result, default
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.testing import fixtures
-import StringIO
+from sqlalchemy.testing.mock import Mock, call
users, metadata, users_autoinc = None, None, None
class ExecuteTest(fixtures.TestBase):
def test_transaction_engine_ctx_begin_fails(self):
engine = engines.testing_engine()
- class MockConnection(Connection):
- closed = False
- def begin(self):
- raise Exception("boom")
-
- def close(self):
- MockConnection.closed = True
- engine._connection_cls = MockConnection
- fn = self._trans_fn()
+
+ mock_connection = Mock(
+ return_value=Mock(
+ begin=Mock(side_effect=Exception("boom"))
+ )
+ )
+ engine._connection_cls = mock_connection
assert_raises(
Exception,
engine.begin
)
- assert MockConnection.closed
+
+ eq_(
+ mock_connection.return_value.close.mock_calls,
+ [call()]
+ )
def test_transaction_engine_ctx_rollback(self):
fn = self._trans_rollback_fn()
import sqlalchemy as tsa
from sqlalchemy.testing import fixtures
from sqlalchemy import testing
+from sqlalchemy.testing.mock import Mock
+
class ParseConnectTest(fixtures.TestBase):
def test_rfc1738(self):
every backend.
"""
- # pretend pysqlite throws the
- # "Cannot operate on a closed database." error
- # on connect. IRL we'd be getting Oracle's "shutdown in progress"
e = create_engine('sqlite://')
sqlite3 = e.dialect.dbapi
- class ThrowOnConnect(MockDBAPI):
- dbapi = sqlite3
- Error = sqlite3.Error
- ProgrammingError = sqlite3.ProgrammingError
- def connect(self, *args, **kw):
- raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
+
+ dbapi = MockDBAPI()
+ dbapi.Error = sqlite3.Error,
+ dbapi.ProgrammingError = sqlite3.ProgrammingError
+ dbapi.connect = Mock(side_effect=sqlite3.ProgrammingError(
+ "Cannot operate on a closed database."))
try:
- create_engine('sqlite://', module=ThrowOnConnect()).connect()
+ create_engine('sqlite://', module=dbapi).connect()
assert False
except tsa.exc.DBAPIError, de:
assert de.connection_invalidated
def dbapi(cls, **kw):
return MockDBAPI()
-class MockDBAPI(object):
- version_info = sqlite_version_info = 99, 9, 9
- sqlite_version = '99.9.9'
-
- def __init__(self, **kwargs):
- self.kwargs = kwargs
- self.paramstyle = 'named'
-
- def connect(self, *args, **kwargs):
- for k in self.kwargs:
+def MockDBAPI(**assert_kwargs):
+ connection = Mock(get_server_version_info=Mock(return_value='5.0'))
+ def connect(*args, **kwargs):
+ for k in assert_kwargs:
assert k in kwargs, 'key %s not present in dictionary' % k
- assert kwargs[k] == self.kwargs[k], \
- 'value %s does not match %s' % (kwargs[k],
- self.kwargs[k])
- return MockConnection()
-
-
-class MockConnection(object):
- def get_server_info(self):
- return '5.0'
-
- def close(self):
- pass
-
- def cursor(self):
- return MockCursor()
-
-class MockCursor(object):
- def close(self):
- pass
+ eq_(
+ kwargs[k], assert_kwargs[k]
+ )
+ return connection
+
+ return Mock(
+ sqlite_version_info=(99, 9, 9,),
+ version_info=(99, 9, 9,),
+ sqlite_version='99.9.9',
+ paramstyle='named',
+ connect=Mock(side_effect=connect)
+ )
mock_dbapi = MockDBAPI()
mock_sqlite_dbapi = msd = MockDBAPI()
import sqlalchemy as tsa
from sqlalchemy import testing
from sqlalchemy.testing.util import gc_collect, lazy_gc
-from sqlalchemy.testing import eq_, assert_raises
+from sqlalchemy.testing import eq_, assert_raises, is_not_
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing import fixtures
-mcid = 1
-class MockDBAPI(object):
- throw_error = False
- def connect(self, *args, **kwargs):
- if self.throw_error:
- raise Exception("couldnt connect !")
- delay = kwargs.pop('delay', 0)
- if delay:
- time.sleep(delay)
- return MockConnection()
-class MockConnection(object):
- closed = False
- def __init__(self):
- global mcid
- self.id = mcid
- mcid += 1
- def close(self):
- self.closed = True
- def rollback(self):
- pass
- def cursor(self):
- return MockCursor()
-class MockCursor(object):
- def execute(self, *args, **kw):
- pass
- def close(self):
- pass
+from sqlalchemy.testing.mock import Mock, call
+
+def MockDBAPI():
+ def cursor():
+ while True:
+ yield Mock()
+ def connect():
+ while True:
+ yield Mock(cursor=Mock(side_effect=cursor()))
+
+ return Mock(connect=Mock(side_effect=connect()))
class PoolTestBase(fixtures.TestBase):
def setup(self):
assert c4 is not c5
def test_manager_with_key(self):
- class NoKws(object):
- def connect(self, arg):
- return MockConnection()
- manager = pool.manage(NoKws(), use_threadlocal=True)
+ dbapi = MockDBAPI()
+ manager = pool.manage(dbapi, use_threadlocal=True)
c1 = manager.connect('foo.db', sa_pool_key="a")
c2 = manager.connect('foo.db', sa_pool_key="b")
assert c1.cursor() is not None
assert c1 is not c2
- assert c1 is c3
-
+ assert c1 is c3
+ eq_(dbapi.connect.mock_calls,
+ [
+ call("foo.db"),
+ call("foo.db"),
+ ]
+ )
def test_bad_args(self):
p = cls(creator=mock_dbapi.connect)
conn = p.connect()
conn.close()
- mock_dbapi.throw_error = True
+ mock_dbapi.connect.side_effect = Exception("error!")
p.dispose()
p.recreate()
self.assert_('foo2' in c.info)
c2 = p.connect()
- self.assert_(c.connection is not c2.connection)
- self.assert_(not c2.info)
- self.assert_('foo2' in c.info)
+ is_not_(c.connection, c2.connection)
+ assert not c2.info
+ assert 'foo2' in c.info
class PoolDialectTest(PoolTestBase):
def test_dispose_closes_pooled(self):
dbapi = MockDBAPI()
- def creator():
- return dbapi.connect()
- p = pool.QueuePool(creator=creator,
+ p = pool.QueuePool(creator=dbapi.connect,
pool_size=2, timeout=None,
max_overflow=0)
c1 = p.connect()
c2 = p.connect()
- conns = [c1.connection, c2.connection]
+ c1_con = c1.connection
+ c2_con = c2.connection
+
c1.close()
- eq_([c.closed for c in conns], [False, False])
+
+ eq_(c1_con.close.call_count, 0)
+ eq_(c2_con.close.call_count, 0)
+
p.dispose()
- eq_([c.closed for c in conns], [True, False])
+
+ eq_(c1_con.close.call_count, 1)
+ eq_(c2_con.close.call_count, 0)
# currently, if a ConnectionFairy is closed
# after the pool has been disposed, there's no
# immediately - it just gets returned to the
# pool normally...
c2.close()
- eq_([c.closed for c in conns], [True, False])
+ eq_(c1_con.close.call_count, 1)
+ eq_(c2_con.close.call_count, 0)
# ...and that's the one we'll get back next.
c3 = p.connect()
- assert c3.connection is conns[1]
+ assert c3.connection is c2_con
def test_no_overflow(self):
self._test_overflow(40, 0)
strong_refs.add(c.connection)
return c
- for j in xrange(5):
- conns = [_conn() for i in xrange(4)]
+ for j in range(5):
+ # open 4 conns at a time. each time this
+ # will yield two pooled connections + two
+ # overflow connections.
+ conns = [_conn() for i in range(4)]
for c in conns:
c.close()
- still_opened = len([c for c in strong_refs if not c.closed])
- eq_(still_opened, 2)
+ # doing that for a total of 5 times yields
+ # ten overflow connections closed plus the
+ # two pooled connections unclosed.
+
+ eq_(
+ set([c.close.call_count for c in strong_refs]),
+ set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0])
+ )
@testing.requires.predictable_gc
def test_weakref_kaboom(self):
dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
c1 = p.connect()
c1.detach()
- c_id = c1.connection.id
- c2 = p.connect()
- assert c2.connection.id != c1.connection.id
- dbapi.raise_error = True
- c2.invalidate()
- c2 = None
c2 = p.connect()
- assert c2.connection.id != c1.connection.id
- con = c1.connection
- assert not con.closed
+ eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])
+
+ c1_con = c1.connection
+ assert c1_con is not None
+ eq_(c1_con.close.call_count, 0)
c1.close()
- assert con.closed
+ eq_(c1_con.close.call_count, 1)
+
+ def test_detach_via_invalidate(self):
+ dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
+
+ c1 = p.connect()
+ c1_con = c1.connection
+ c1.invalidate()
+ assert c1.connection is None
+ eq_(c1_con.close.call_count, 1)
+
+ c2 = p.connect()
+ assert c2.connection is not c1_con
+ c2_con = c2.connection
+
+ c2.close()
+ eq_(c2_con.close.call_count, 0)
def test_threadfairy(self):
p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True)
been called."""
dbapi = MockDBAPI()
- p = pool.SingletonThreadPool(creator=dbapi.connect,
- pool_size=3)
+
+ lock = threading.Lock()
+ def creator():
+ # the mock iterator isn't threadsafe...
+ with lock:
+ return dbapi.connect()
+ p = pool.SingletonThreadPool(creator=creator, pool_size=3)
if strong_refs:
sr = set()
assert len(p._all_conns) == 3
if strong_refs:
- still_opened = len([c for c in sr if not c.closed])
+ still_opened = len([c for c in sr if not c.close.call_count])
eq_(still_opened, 3)
class AssertionPoolTest(PoolTestBase):
dbapi = MockDBAPI()
p = pool.NullPool(creator=lambda: dbapi.connect('foo.db'))
c1 = p.connect()
- c_id = c1.connection.id
+
c1.close()
c1 = None
c1 = p.connect()
- dbapi.raise_error = True
c1.invalidate()
c1 = None
c1 = p.connect()
- assert c1.connection.id != c_id
+ dbapi.connect.assert_has_calls([
+ call('foo.db'),
+ call('foo.db')],
+ any_order=True)
class StaticPoolTest(PoolTestBase):
from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
import time
-import weakref
-from sqlalchemy import select, MetaData, Integer, String, pool, create_engine
+from sqlalchemy import select, MetaData, Integer, String, create_engine, pool
from sqlalchemy.testing.schema import Table, Column
import sqlalchemy as tsa
from sqlalchemy import testing
from sqlalchemy import exc
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.engines import testing_engine
+from sqlalchemy.testing import is_not_
+from sqlalchemy.testing.mock import Mock, call
class MockError(Exception):
pass
class MockDisconnect(MockError):
pass
-class MockDBAPI(object):
- def __init__(self):
- self.paramstyle = 'named'
- self.connections = weakref.WeakKeyDictionary()
- def connect(self, *args, **kwargs):
- return MockConnection(self)
- def shutdown(self, explode='execute'):
- for c in self.connections:
- c.explode = explode
- Error = MockError
-
-class MockConnection(object):
- def __init__(self, dbapi):
- dbapi.connections[self] = True
- self.explode = ""
- def rollback(self):
- if self.explode == 'rollback':
+def mock_connection():
+ def mock_cursor():
+ def execute(*args, **kwargs):
+ if conn.explode == 'execute':
+ raise MockDisconnect("Lost the DB connection on execute")
+ elif conn.explode in ('execute_no_disconnect', ):
+ raise MockError(
+ "something broke on execute but we didn't lose the connection")
+ elif conn.explode in ('rollback', 'rollback_no_disconnect'):
+ raise MockError(
+ "something broke on execute but we didn't lose the connection")
+ elif args and "SELECT" in args[0]:
+ cursor.description = [('foo', None, None, None, None, None)]
+ else:
+ return
+
+ def close():
+ cursor.fetchall = cursor.fetchone = \
+ Mock(side_effect=MockError("cursor closed"))
+ cursor = Mock(
+ execute=Mock(side_effect=execute),
+ close=Mock(side_effect=close)
+ )
+ return cursor
+
+ def cursor():
+ while True:
+ yield mock_cursor()
+
+ def rollback():
+ if conn.explode == 'rollback':
raise MockDisconnect("Lost the DB connection on rollback")
- if self.explode == 'rollback_no_disconnect':
+ if conn.explode == 'rollback_no_disconnect':
raise MockError(
"something broke on rollback but we didn't lose the connection")
else:
return
- def commit(self):
- pass
- def cursor(self):
- return MockCursor(self)
- def close(self):
- pass
-
-class MockCursor(object):
- def __init__(self, parent):
- self.explode = parent.explode
- self.description = ()
- self.closed = False
- def execute(self, *args, **kwargs):
- if self.explode == 'execute':
- raise MockDisconnect("Lost the DB connection on execute")
- elif self.explode in ('execute_no_disconnect', ):
- raise MockError(
- "something broke on execute but we didn't lose the connection")
- elif self.explode in ('rollback', 'rollback_no_disconnect'):
- raise MockError(
- "something broke on execute but we didn't lose the connection")
- elif args and "select" in args[0]:
- self.description = [('foo', None, None, None, None, None)]
- else:
- return
- def fetchall(self):
- if self.closed:
- raise MockError("cursor closed")
- return []
- def fetchone(self):
- if self.closed:
- raise MockError("cursor closed")
- return None
- def close(self):
- self.closed = True
-
-db, dbapi = None, None
+
+ conn = Mock(
+ rollback=Mock(side_effect=rollback),
+ cursor=Mock(side_effect=cursor())
+ )
+ return conn
+
+def MockDBAPI():
+ connections = []
+ def connect():
+ while True:
+ conn = mock_connection()
+ connections.append(conn)
+ yield conn
+
+ def shutdown(explode='execute'):
+ for c in connections:
+ c.explode = explode
+
+ def dispose():
+ for c in connections:
+ c.explode = None
+ connections[:] = []
+
+ return Mock(
+ connect=Mock(side_effect=connect()),
+ shutdown=Mock(side_effect=shutdown),
+ dispose=Mock(side_effect=dispose),
+ paramstyle='named',
+ connections=connections,
+ Error=MockError
+ )
+
+
class MockReconnectTest(fixtures.TestBase):
def setup(self):
- global db, dbapi
- dbapi = MockDBAPI()
+ self.dbapi = MockDBAPI()
- # note - using straight create_engine here
- # since we are testing gc
- db = create_engine(
+ self.db = testing_engine(
'postgresql://foo:bar@localhost/test',
- module=dbapi, _initialize=False)
+ options=dict(module=self.dbapi, _initialize=False))
+ self.mock_connect = call(host='localhost', password='bar',
+ user='foo', database='test')
# monkeypatch disconnect checker
- db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect)
+ self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect)
def teardown(self):
- db.dispose()
+ self.dbapi.dispose()
def test_reconnect(self):
"""test that an 'is_disconnect' condition will invalidate the
connection, and additionally dispose the previous connection
pool and recreate."""
- pid = id(db.pool)
+ db_pool = self.db.pool
# make a connection
- conn = db.connect()
+ conn = self.db.connect()
# connection works
# create a second connection within the pool, which we'll ensure
# also goes away
- conn2 = db.connect()
+ conn2 = self.db.connect()
conn2.close()
# two connections opened total now
- assert len(dbapi.connections) == 2
+ assert len(self.dbapi.connections) == 2
# set it to fail
- dbapi.shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError:
- pass
+ self.dbapi.shutdown()
+ assert_raises(
+ tsa.exc.DBAPIError,
+ conn.execute, select([1])
+ )
# assert was invalidated
# close shouldnt break
conn.close()
- assert id(db.pool) != pid
+ is_not_(self.db.pool, db_pool)
# ensure all connections closed (pool was recycled)
- gc_collect()
- assert len(dbapi.connections) == 0
- conn = db.connect()
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], [call()]]
+ )
+
+ conn = self.db.connect()
conn.execute(select([1]))
conn.close()
- assert len(dbapi.connections) == 1
+
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], [call()], []]
+ )
def test_invalidate_trans(self):
- conn = db.connect()
+ conn = self.db.connect()
trans = conn.begin()
- dbapi.shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError:
- pass
+ self.dbapi.shutdown()
- # assert was invalidated
+ assert_raises(
+ tsa.exc.DBAPIError,
+ conn.execute, select([1])
+ )
- gc_collect()
- assert len(dbapi.connections) == 0
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()]]
+ )
assert not conn.closed
assert conn.invalidated
assert trans.is_active
conn.execute, select([1])
)
assert trans.is_active
- try:
- trans.commit()
- assert False
- except tsa.exc.InvalidRequestError, e:
- assert str(e) \
- == "Can't reconnect until invalid transaction is "\
- "rolled back"
+
+ assert_raises_message(
+ tsa.exc.InvalidRequestError,
+ "Can't reconnect until invalid transaction is "
+ "rolled back",
+ trans.commit
+ )
+
assert trans.is_active
trans.rollback()
assert not trans.is_active
conn.execute(select([1]))
assert not conn.invalidated
- assert len(dbapi.connections) == 1
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], []]
+ )
def test_conn_reusable(self):
- conn = db.connect()
+ conn = self.db.connect()
conn.execute(select([1]))
- assert len(dbapi.connections) == 1
+ eq_(
+ self.dbapi.connect.mock_calls,
+ [self.mock_connect]
+ )
- dbapi.shutdown()
+ self.dbapi.shutdown()
assert_raises(
tsa.exc.DBAPIError,
assert not conn.closed
assert conn.invalidated
- # ensure all connections closed (pool was recycled)
- gc_collect()
- assert len(dbapi.connections) == 0
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()]]
+ )
# test reconnects
conn.execute(select([1]))
assert not conn.invalidated
- assert len(dbapi.connections) == 1
+
+ eq_(
+ [c.close.mock_calls for c in self.dbapi.connections],
+ [[call()], []]
+ )
def test_invalidated_close(self):
- conn = db.connect()
+ conn = self.db.connect()
- dbapi.shutdown()
+ self.dbapi.shutdown()
assert_raises(
tsa.exc.DBAPIError,
)
def test_noreconnect_execute_plus_closewresult(self):
- conn = db.connect(close_with_result=True)
+ conn = self.db.connect(close_with_result=True)
- dbapi.shutdown("execute_no_disconnect")
+ self.dbapi.shutdown("execute_no_disconnect")
# raises error
assert_raises_message(
assert not conn.invalidated
def test_noreconnect_rollback_plus_closewresult(self):
- conn = db.connect(close_with_result=True)
+ conn = self.db.connect(close_with_result=True)
- dbapi.shutdown("rollback_no_disconnect")
+ self.dbapi.shutdown("rollback_no_disconnect")
# raises error
assert_raises_message(
)
def test_reconnect_on_reentrant(self):
- conn = db.connect()
+ conn = self.db.connect()
conn.execute(select([1]))
- assert len(dbapi.connections) == 1
+ assert len(self.dbapi.connections) == 1
- dbapi.shutdown("rollback")
+ self.dbapi.shutdown("rollback")
# raises error
assert_raises_message(
assert conn.invalidated
def test_reconnect_on_reentrant_plus_closewresult(self):
- conn = db.connect(close_with_result=True)
+ conn = self.db.connect(close_with_result=True)
- dbapi.shutdown("rollback")
+ self.dbapi.shutdown("rollback")
# raises error
assert_raises_message(
)
def test_check_disconnect_no_cursor(self):
- conn = db.connect()
- result = conn.execute("select 1")
+ conn = self.db.connect()
+ result = conn.execute(select([1]))
result.cursor.close()
conn.close()
+
assert_raises_message(
tsa.exc.DBAPIError,
"cursor closed",
class CursorErrTest(fixtures.TestBase):
def setup(self):
- global db, dbapi
-
- class MDBAPI(MockDBAPI):
- def connect(self, *args, **kwargs):
- return MConn(self)
-
- class MConn(MockConnection):
- def cursor(self):
- return MCursor(self)
+ def MockDBAPI():
+ def cursor():
+ while True:
+ yield Mock(
+ description=[],
+ close=Mock(side_effect=Exception("explode")))
+ def connect():
+ while True:
+ yield Mock(cursor=Mock(side_effect=cursor()))
+
+ return Mock(connect=Mock(side_effect=connect()))
- class MCursor(MockCursor):
- def close(self):
- raise Exception("explode")
-
- dbapi = MDBAPI()
-
- db = testing_engine(
+ dbapi = MockDBAPI()
+ self.db = testing_engine(
'postgresql://foo:bar@localhost/test',
options=dict(module=dbapi, _initialize=False))
def test_cursor_explode(self):
- conn = db.connect()
+ conn = self.db.connect()
result = conn.execute("select foo")
result.close()
conn.close()
def teardown(self):
- db.dispose()
+ self.db.dispose()
+
+
+def _assert_invalidated(fn, *args):
+ try:
+ fn(*args)
+ assert False
+ except tsa.exc.DBAPIError as e:
+ if not e.connection_invalidated:
+ raise
-engine = None
class RealReconnectTest(fixtures.TestBase):
def setup(self):
- global engine
- engine = engines.reconnecting_engine()
+ self.engine = engines.reconnecting_engine()
def teardown(self):
- engine.dispose()
+ self.engine.dispose()
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_reconnect(self):
- conn = engine.connect()
+ conn = self.engine.connect()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
- engine.test_shutdown()
+ self.engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- if not e.connection_invalidated:
- raise
+ _assert_invalidated(conn.execute, select([1]))
assert not conn.closed
assert conn.invalidated
assert not conn.invalidated
# one more time
- engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- if not e.connection_invalidated:
- raise
+ self.engine.test_shutdown()
+ _assert_invalidated(conn.execute, select([1]))
+
assert conn.invalidated
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
conn.close()
def test_multiple_invalidate(self):
- c1 = engine.connect()
- c2 = engine.connect()
+ c1 = self.engine.connect()
+ c2 = self.engine.connect()
eq_(c1.execute(select([1])).scalar(), 1)
- p1 = engine.pool
- engine.test_shutdown()
+ p1 = self.engine.pool
+ self.engine.test_shutdown()
- try:
- c1.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- assert e.connection_invalidated
+ _assert_invalidated(c1.execute, select([1]))
- p2 = engine.pool
+ p2 = self.engine.pool
- try:
- c2.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- assert e.connection_invalidated
+ _assert_invalidated(c2.execute, select([1]))
# pool isn't replaced
- assert engine.pool is p2
+ assert self.engine.pool is p2
def test_ensure_is_disconnect_gets_connection(self):
# though MySQLdb we get a non-working cursor.
# assert cursor is None
- engine.dialect.is_disconnect = is_disconnect
- conn = engine.connect()
- engine.test_shutdown()
+ self.engine.dialect.is_disconnect = is_disconnect
+ conn = self.engine.connect()
+ self.engine.test_shutdown()
assert_raises(
tsa.exc.DBAPIError,
conn.execute, select([1])
)
def test_rollback_on_invalid_plain(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin()
conn.invalidate()
trans.rollback()
@testing.requires.two_phase_transactions
def test_rollback_on_invalid_twophase(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin_twophase()
conn.invalidate()
trans.rollback()
@testing.requires.savepoints
def test_rollback_on_invalid_savepoint(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin()
trans2 = conn.begin_nested()
conn.invalidate()
trans2.rollback()
def test_invalidate_twice(self):
- conn = engine.connect()
+ conn = self.engine.connect()
conn.invalidate()
conn.invalidate()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- if not e.connection_invalidated:
- raise
+ _assert_invalidated(conn.execute, select([1]))
assert not conn.closed
assert conn.invalidated
eq_(conn.execute(select([1])).scalar(), 1)
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_close(self):
- conn = engine.connect()
+ conn = self.engine.connect()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
- engine.test_shutdown()
+ self.engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- if not e.connection_invalidated:
- raise
+ _assert_invalidated(conn.execute, select([1]))
conn.close()
- conn = engine.connect()
+ conn = self.engine.connect()
eq_(conn.execute(select([1])).scalar(), 1)
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_with_transaction(self):
- conn = engine.connect()
+ conn = self.engine.connect()
trans = conn.begin()
eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
- engine.test_shutdown()
- try:
- conn.execute(select([1]))
- assert False
- except tsa.exc.DBAPIError, e:
- if not e.connection_invalidated:
- raise
+ self.engine.test_shutdown()
+ _assert_invalidated(conn.execute, select([1]))
assert not conn.closed
assert conn.invalidated
assert trans.is_active
conn.execute, select([1])
)
assert trans.is_active
- try:
- trans.commit()
- assert False
- except tsa.exc.InvalidRequestError, e:
- assert str(e) \
- == "Can't reconnect until invalid transaction is "\
- "rolled back"
+ assert_raises_message(
+ tsa.exc.InvalidRequestError,
+ "Can't reconnect until invalid transaction is rolled back",
+ trans.commit
+ )
assert trans.is_active
trans.rollback()
assert not trans.is_active
eq_(conn.execute(select([1])).scalar(), 1)
conn.close()
-meta, table, engine = None, None, None
class InvalidateDuringResultTest(fixtures.TestBase):
def setup(self):
- global meta, table, engine
- engine = engines.reconnecting_engine()
- meta = MetaData(engine)
- table = Table('sometable', meta,
+ self.engine = engines.reconnecting_engine()
+ self.meta = MetaData(self.engine)
+ table = Table('sometable', self.meta,
Column('id', Integer, primary_key=True),
Column('name', String(50)))
- meta.create_all()
+ self.meta.create_all()
table.insert().execute(
- [{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
+ [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)]
)
def teardown(self):
- meta.drop_all()
- engine.dispose()
+ self.meta.drop_all()
+ self.engine.dispose()
@testing.fails_if([
'+mysqlconnector', '+mysqldb',
@testing.fails_on('+informixdb',
"Wrong error thrown, fix in informixdb?")
def test_invalidate_on_results(self):
- conn = engine.connect()
+ conn = self.engine.connect()
result = conn.execute('select * from sometable')
for x in xrange(20):
result.fetchone()
- engine.test_shutdown()
- try:
- print 'ghost result: %r' % result.fetchone()
- assert False
- except tsa.exc.DBAPIError, e:
- if not e.connection_invalidated:
- raise
+ self.engine.test_shutdown()
+ _assert_invalidated(result.fetchone)
assert conn.invalidated