--- /dev/null
+.. change::
+ :tags: bug, pool, asyncio
+ :tickets: 5823
+
+ When using an asyncio engine, the connection pool will now detach and
+ discard a pooled connection that was not explicitly closed or returned
+ to the pool when its tracking object is garbage collected, emitting a
+ warning that the connection was not properly closed. As this operation
+ occurs during Python gc finalizers, it's not safe to run any IO
+ operations on the connection, including transaction rollback or
+ connection close, as these will often occur outside of the event loop.
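+
+ A minimal sketch of the pattern that now emits the warning, assuming an
+ asyncio driver such as aiosqlite is installed::
+
+     from sqlalchemy.ext.asyncio import create_async_engine
+
+     async def main():
+         engine = create_async_engine("sqlite+aiosqlite://")
+         conn = await engine.connect()
+         # conn is never closed nor returned to the pool; when its
+         # tracking object is garbage collected, the pool detaches and
+         # discards the DBAPI connection and emits the warning, rather
+         # than attempting IO outside the event loop
+         del conn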
+
+
--- /dev/null
+.. change::
+ :tags: bug, asyncio
+ :tickets: 5827
+
+ Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would
+ be raised rather than :class:`.exc.TimeoutError`. Also repaired the
+ :paramref:`_sa.create_engine.pool_timeout` parameter when set to zero with
+ the async engine; previously the timeout would be ignored and the pool
+ would block, rather than timing out immediately as is the behavior of the
+ regular :class:`.QueuePool`.
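+
+ A sketch of the repaired behavior; the URL and pool settings here are
+ illustrative only::
+
+     from sqlalchemy import exc
+     from sqlalchemy.ext.asyncio import create_async_engine
+
+     engine = create_async_engine(
+         "postgresql+asyncpg://scott:tiger@localhost/test",
+         pool_size=1,
+         max_overflow=0,
+         pool_timeout=0,  # zero now means "time out immediately"
+     )
+
+     async def demo():
+         c1 = await engine.connect()
+         try:
+             # the pool of size one is exhausted; previously this would
+             # block, now it raises exc.TimeoutError at once
+             await engine.connect()
+         except exc.TimeoutError:
+             await c1.close()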
def has_table(self, connection, tablename, dbname, owner, schema):
if tablename.startswith("#"): # temporary table
tables = ischema.mssql_temp_table_columns
- result = connection.execute(
- sql.select(tables.c.table_name)
- .where(
- tables.c.table_name.like(
- self._temp_table_name_like_pattern(tablename)
- )
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
)
- .limit(1)
)
+
+ result = connection.execute(s.limit(1))
return result.scalar() is not None
else:
tables = ischema.tables
+from sqlalchemy import inspect
+from sqlalchemy import Integer
from ... import create_engine
from ... import exc
+from ...schema import Column
+from ...schema import DropConstraint
+from ...schema import ForeignKeyConstraint
+from ...schema import MetaData
+from ...schema import Table
from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
# "where database_id=db_id('%s')" % ident):
# log.info("killing SQL server session %s", row['session_id'])
# conn.exec_driver_sql("kill %s" % row['session_id'])
-
conn.exec_driver_sql("drop database %s" % ident)
log.info("Reaped db: %s", ident)
return True
@get_temp_table_name.for_db("mssql")
def _mssql_get_temp_table_name(cfg, eng, base_name):
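+    # use a global temporary table ("##") so that the table is visible
+    # across connections; a local "#" temp table is private to the
+    # creating session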
- return "#" + base_name
+ return "##" + base_name
+
+
+@drop_all_schema_objects_pre_tables.for_db("mssql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ inspector = inspect(conn)
+ for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
+ for tname in inspector.get_table_names(schema=schema):
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
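+                # the Table with placeholder columns x / y exists only
+                # to give DropConstraint a constraint object carrying
+                # the correct name; the columns don't affect the
+                # emitted DROP DDL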
+ for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint(
+ [tb.c.x], [tb.c.y], name=fk["name"]
+ )
+ )
+ )
* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail.
+
.. _cx_oracle_unicode:
Unicode
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
+from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import set_default_schema_on_connection
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
-from ...testing.provision import update_db_opts
@create_db.for_db("oracle")
_ora_drop_ignore(conn, "%s_ts2" % ident)
-@update_db_opts.for_db("oracle")
-def _oracle_update_db_opts(db_url, db_opts):
- pass
+@stop_test_class_outside_fixtures.for_db("oracle")
+def stop_test_class_outside_fixtures(config, db, cls):
+ with db.begin() as conn:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa E501
+ conn.exec_driver_sql("purge recyclebin")
-@stop_test_class.for_db("oracle")
-def stop_test_class(config, db, cls):
- """run magic command to get rid of identity sequences
+ # clear statement cache on all connections that were used
+ # https://github.com/oracle/python-cx_Oracle/issues/519
- # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/
+ for cx_oracle_conn in _all_conns:
+ try:
+ sc = cx_oracle_conn.stmtcachesize
+ except db.dialect.dbapi.InterfaceError:
+ # connection closed
+ pass
+ else:
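+            # setting the statement cache size to zero purges existing
+            # cached statements; restoring the previous size re-enables
+            # caching for subsequent use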
+ cx_oracle_conn.stmtcachesize = 0
+ cx_oracle_conn.stmtcachesize = sc
+ _all_conns.clear()
- """
- with db.begin() as conn:
- conn.exec_driver_sql("purge recyclebin")
+_all_conns = set()
+
+
+@post_configure_engine.for_db("oracle")
+def _oracle_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "checkout")
+ def checkout(dbapi_con, con_record, con_proxy):
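+        # track raw DBAPI connections so that their statement caches
+        # can be cleared in stop_test_class_outside_fixtures above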
+ _all_conns.add(dbapi_con)
@run_reap_dbs.for_db("oracle")
def rollback(self):
if self._started:
self.await_(self._transaction.rollback())
-
self._transaction = None
self._started = False
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import log
+from ...testing.provision import prepare_for_drop_tables
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import temp_table_keyword_args
postgresql.ENUM(name=enum["name"], schema=enum["schema"])
)
)
+
+
+@prepare_for_drop_tables.for_db("postgresql")
+def prepare_for_drop_tables(config, connection):
+ """Ensure there are no locks on the current username/database."""
+
+ result = connection.exec_driver_sql(
+ "select pid, state, wait_event_type, query "
+ # "select pg_terminate_backend(pid), state, wait_event_type "
+ "from pg_stat_activity where "
+ "usename=current_user "
+ "and datname=current_database() and state='idle in transaction' "
+ "and pid != pg_backend_pid()"
+ )
+ rows = result.all() # noqa
+ assert not rows, (
+ "PostgreSQL may not be able to DROP tables due to "
+ "idle in transaction: %s"
+ % ("; ".join(row._mapping["query"] for row in rows))
+ )
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
os.remove(path)
-@stop_test_class.for_db("sqlite")
-def stop_test_class(config, db, cls):
+@stop_test_class_outside_fixtures.for_db("sqlite")
+def stop_test_class_outside_fixtures(config, db, cls):
with db.connect() as conn:
files = [
row.file
return self.conn
def __exit__(self, type_, value, traceback):
-
- if type_ is not None:
- self.transaction.rollback()
- else:
- if self.transaction.is_active:
- self.transaction.commit()
- if not self.close_with_result:
- self.conn.close()
+ try:
+ if type_ is not None:
+ if self.transaction.is_active:
+ self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ self.transaction.commit()
+ finally:
+ if not self.close_with_result:
+ self.conn.close()
def begin(self, close_with_result=False):
"""Return a context manager delivering a :class:`_engine.Connection`
c = base.Connection(
engine, connection=dbapi_connection, _has_events=False
)
- c._execution_options = util.immutabledict()
- dialect.initialize(c)
- dialect.do_rollback(c.connection)
+ c._execution_options = util.EMPTY_DICT
+
+ try:
+ dialect.initialize(c)
+ finally:
+ dialect.do_rollback(c.connection)
# previously, the "first_connect" event was used here, which was then
# scaled back if the "on_connect" handler were present. now,
return self.conn
def __exit__(self, type_, value, traceback):
- if type_ is not None:
- self.transaction.rollback()
- else:
- if self.transaction.is_active:
- self.transaction.commit()
- self.conn.close()
+ try:
+ if type_ is not None:
+ if self.transaction.is_active:
+ self.transaction.rollback()
+ else:
+ if self.transaction.is_active:
+ self.transaction.commit()
+ finally:
+ self.conn.close()
def begin(self):
"""Return a :class:`_future.Connection` object with a transaction
rec._checkin_failed(err)
echo = pool._should_log_debug()
fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+
rec.fairy_ref = weakref.ref(
fairy,
lambda ref: _finalize_fairy
assert connection is None
connection = connection_record.connection
+ dont_restore_gced = pool._is_asyncio
+
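+    # when the pool is asyncio, a fairy that is being finalized due to
+    # garbage collection (ref is set) can't safely perform IO; the
+    # connection can only be detached and discarded, not reset and
+    # returned to the pool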
+ if dont_restore_gced:
+ detach = not connection_record or ref
+ can_manipulate_connection = not ref
+ else:
+ detach = not connection_record
+ can_manipulate_connection = True
+
if connection is not None:
if connection_record and echo:
pool.logger.debug(
connection, connection_record, echo
)
assert fairy.connection is connection
- fairy._reset(pool)
+ if can_manipulate_connection:
+ fairy._reset(pool)
+
+ if detach:
+ if connection_record:
+ fairy._pool = pool
+ fairy.detach()
+
+ if can_manipulate_connection:
+ if pool.dispatch.close_detached:
+ pool.dispatch.close_detached(connection)
+
+ pool._close_connection(connection)
+ else:
+ util.warn(
+ "asyncio connection is being garbage "
+ "collected without being properly closed: %r"
+ % connection
+ )
- # Immediately close detached instances
- if not connection_record:
- if pool.dispatch.close_detached:
- pool.dispatch.close_detached(connection)
- pool._close_connection(connection)
except BaseException as e:
pool.logger.error(
"Exception during reset or similar", exc_info=True
from .assertions import is_ # noqa
from .assertions import is_false # noqa
from .assertions import is_instance_of # noqa
+from .assertions import is_none # noqa
from .assertions import is_not # noqa
from .assertions import is_not_ # noqa
+from .assertions import is_not_none # noqa
from .assertions import is_true # noqa
from .assertions import le_ # noqa
from .assertions import ne_ # noqa
is_(bool(a), False, msg=msg)
+def is_none(a, msg=None):
+ is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+ is_not(a, None, msg=msg)
+
+
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
return _fixture_functions.get_current_test_name()
+def mark_base_test_class():
+ return _fixture_functions.mark_base_test_class()
+
+
class Config(object):
def __init__(self, db, db_opts, options, file_config):
self._set_name(db)
from __future__ import absolute_import
+import collections
import re
import warnings
import weakref
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
- self.testing_engines = weakref.WeakKeyDictionary()
- self.conns = set()
+ self.testing_engines = collections.defaultdict(set)
+ self.dbapi_connections = set()
def add_pool(self, pool):
- event.listen(pool, "connect", self.connect)
- event.listen(pool, "checkout", self.checkout)
- event.listen(pool, "invalidate", self.invalidate)
-
- def add_engine(self, engine):
- self.add_pool(engine.pool)
- self.testing_engines[engine] = True
+ event.listen(pool, "checkout", self._add_conn)
+ event.listen(pool, "checkin", self._remove_conn)
+ event.listen(pool, "close", self._remove_conn)
+ event.listen(pool, "close_detached", self._remove_conn)
+        # note we are keeping "invalidated" connections in the set, as
+        # those are still open connections we would like to roll back
+
+ def _add_conn(self, dbapi_con, con_record, con_proxy):
+ self.dbapi_connections.add(dbapi_con)
+ self.proxy_refs[con_proxy] = True
- def connect(self, dbapi_conn, con_record):
- self.conns.add((dbapi_conn, con_record))
+ def _remove_conn(self, dbapi_conn, *arg):
+ self.dbapi_connections.discard(dbapi_conn)
- def checkout(self, dbapi_con, con_record, con_proxy):
- self.proxy_refs[con_proxy] = True
+ def add_engine(self, engine, scope):
+ self.add_pool(engine.pool)
- def invalidate(self, dbapi_con, con_record, exception):
- self.conns.discard((dbapi_con, con_record))
+ assert scope in ("class", "global", "function", "fixture")
+ self.testing_engines[scope].add(engine)
def _safe(self, fn):
try:
if rec is not None and rec.is_valid:
self._safe(rec.rollback)
- def close_all(self):
+ def checkin_all(self):
+ # run pool.checkin() for all ConnectionFairy instances we have
+ # tracked.
+
for rec in list(self.proxy_refs):
if rec is not None and rec.is_valid:
- self._safe(rec._close)
-
- def _after_test_ctx(self):
- # this can cause a deadlock with pg8000 - pg8000 acquires
- # prepared statement lock inside of rollback() - if async gc
- # is collecting in finalize_fairy, deadlock.
- # not sure if this should be for non-cpython only.
- # note that firebird/fdb definitely needs this though
- for conn, rec in list(self.conns):
- if rec.connection is None:
- # this is a hint that the connection is closed, which
- # is causing segfaults on mysqlclient due to
- # https://github.com/PyMySQL/mysqlclient-python/issues/270;
- # try to work around here
- continue
- self._safe(conn.rollback)
-
- def _stop_test_ctx(self):
- if config.options.low_connections:
- self._stop_test_ctx_minimal()
- else:
- self._stop_test_ctx_aggressive()
+ self.dbapi_connections.discard(rec.connection)
+ self._safe(rec._checkin)
- def _stop_test_ctx_minimal(self):
- self.close_all()
+        # for fairy refs that were GCed and could not close their
+        # connection, such as under asyncio, roll back those remaining
+        # connections
+ for con in self.dbapi_connections:
+ self._safe(con.rollback)
+ self.dbapi_connections.clear()
- self.conns = set()
+ def close_all(self):
+ self.checkin_all()
- for rec in list(self.testing_engines):
- if rec is not config.db:
- rec.dispose()
+ def prepare_for_drop_tables(self, connection):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
- def _stop_test_ctx_aggressive(self):
- self.close_all()
- for conn, rec in list(self.conns):
- self._safe(conn.close)
- rec.connection = None
+ from . import provision
+
+ provision.prepare_for_drop_tables(connection.engine.url, connection)
+
+ def _drop_testing_engines(self, scope):
+ eng = self.testing_engines[scope]
+ for rec in list(eng):
+ for proxy_ref in list(self.proxy_refs):
+ if proxy_ref is not None and proxy_ref.is_valid:
+ if (
+ proxy_ref._pool is not None
+ and proxy_ref._pool is rec.pool
+ ):
+ self._safe(proxy_ref._checkin)
+ rec.dispose()
+ eng.clear()
+
+ def after_test(self):
+ self._drop_testing_engines("function")
+
+ def after_test_outside_fixtures(self, test):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ if test.__class__.__leave_connections_for_teardown__:
+ return
- self.conns = set()
- for rec in list(self.testing_engines):
- if hasattr(rec, "sync_engine"):
- rec.sync_engine.dispose()
- else:
- rec.dispose()
+ self.checkin_all()
+
+ # on PostgreSQL, this will test for any "idle in transaction"
+ # connections. useful to identify tests with unusual patterns
+ # that can't be cleaned up correctly.
+ from . import provision
+
+ with config.db.connect() as conn:
+ provision.prepare_for_drop_tables(conn.engine.url, conn)
+
+ def stop_test_class_inside_fixtures(self):
+ self.checkin_all()
+ self._drop_testing_engines("function")
+ self._drop_testing_engines("class")
+
+ def final_cleanup(self):
+ self.checkin_all()
+ for scope in self.testing_engines:
+ self._drop_testing_engines(scope)
def assert_all_closed(self):
for rec in self.proxy_refs:
testing_reaper = ConnectionKiller()
-def drop_all_tables(metadata, bind):
- testing_reaper.close_all()
- if hasattr(bind, "close"):
- bind.close()
-
- if not config.db.dialect.supports_alter:
- from . import assertions
-
- with assertions.expect_warnings("Can't sort tables", assert_=False):
- metadata.drop_all(bind)
- else:
- metadata.drop_all(bind)
-
-
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
fn(*args, **kw)
try:
fn(*args, **kw)
finally:
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
def all_dialects(exclude=None):
return engine
-def testing_engine(url=None, options=None, future=False, asyncio=False):
+def testing_engine(url=None, options=None, future=None, asyncio=False):
"""Produce an engine configured by --options with optional overrides."""
if asyncio:
from sqlalchemy.ext.asyncio import create_async_engine as create_engine
- elif future or config.db and config.db._is_future:
+ elif future or (
+ config.db and config.db._is_future and future is not False
+ ):
from sqlalchemy.future import create_engine
else:
from sqlalchemy import create_engine
if not options:
use_reaper = True
+ scope = "function"
else:
use_reaper = options.pop("use_reaper", True)
+ scope = options.pop("scope", "function")
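+        # scope controls when the reaper disposes the engine:
+        # "function" after each test, "class" at class teardown,
+        # "fixture" via the testing_engine fixture, "global" only at
+        # final cleanup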
url = url or config.db.url
default_opt.update(options)
engine = create_engine(url, **options)
- if asyncio:
- engine.sync_engine._has_events = True
- else:
- engine._has_events = True # enable event blocks, helps with profiling
+
+    if scope == "global":
+        if asyncio:
+            engine.sync_engine._has_events = True
+        else:
+            # enable event blocks, helps with profiling
+            engine._has_events = True
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
- engine.pool._max_overflow = 5
+ engine.pool._max_overflow = 0
if use_reaper:
- testing_reaper.add_engine(engine)
+ testing_reaper.add_engine(engine, scope)
return engine
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+import contextlib
import re
import sys
from . import assertions
from . import config
from . import schema
-from .engines import drop_all_tables
-from .engines import testing_engine
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
from .util import adict
+from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
from ..orm import declarative_base
from ..orm.decl_api import DeclarativeMeta
from ..schema import sort_tables_and_constraints
-# whether or not we use unittest changes things dramatically,
-# as far as how pytest collection works.
-
+@config.mark_base_test_class()
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
# skipped.
__skip_if__ = None
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
def assert_(self, val, msg=None):
assert val, msg
- # apparently a handful of tests are doing this....OK
- def setup(self):
- if hasattr(self, "setUp"):
- self.setUp()
-
- def teardown(self):
- if hasattr(self, "tearDown"):
- self.tearDown()
-
@config.fixture()
def connection(self):
- eng = getattr(self, "bind", config.db)
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+        # trans would not be active here if the test is still using
+        # the legacy @provide_metadata decorator, as that runs a
+        # close of all connections.
+ conn.close()
@config.fixture()
- def future_connection(self):
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
- eng = testing_engine(future=True)
- conn = eng.connect()
- trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None, options=None, future=False, asyncio=False
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url, options=options, future=future, asyncio=asyncio
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
@config.fixture()
- def metadata(self):
+ def metadata(self, request):
"""Provide bound MetaData for a single test, dropping afterwards."""
- from . import engines
from ..sql import schema
metadata = schema.MetaData()
- try:
- yield metadata
- finally:
- engines.drop_all_tables(metadata, config.db)
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
-class FutureEngineMixin(object):
- @classmethod
- def setup_class(cls):
- from ..future.engine import Engine
- from sqlalchemy import testing
+_connection_fixture_connection = None
- facade = Engine._future_facade(config.db)
- config._current.push_engine(facade, testing)
- super_ = super(FutureEngineMixin, cls)
- if hasattr(super_, "setup_class"):
- super_.setup_class()
+@contextlib.contextmanager
+def _push_future_engine(engine):
- @classmethod
- def teardown_class(cls):
- super_ = super(FutureEngineMixin, cls)
- if hasattr(super_, "teardown_class"):
- super_.teardown_class()
+ from ..future.engine import Engine
+ from sqlalchemy import testing
+
+ facade = Engine._future_facade(engine)
+ config._current.push_engine(facade, testing)
+
+ yield facade
- from sqlalchemy import testing
+ config._current.pop(testing)
- config._current.pop(testing)
+
+class FutureEngineMixin(object):
+ @config.fixture(autouse=True, scope="class")
+ def _push_future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
class TablesTest(TestBase):
other = None
sequences = None
- @property
- def tables_test_metadata(self):
- return self._tables_metadata
-
- @classmethod
- def setup_class(cls):
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
cls._init_class()
cls._setup_once_tables()
cls._setup_once_inserts()
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
@classmethod
def _init_class(cls):
if cls.run_define_tables == "each":
if self.run_define_tables == "each":
self.tables.clear()
if self.run_create_tables == "each":
- drop_all_tables(self._tables_metadata, self.bind)
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
self._tables_metadata.clear()
elif self.run_create_tables == "each":
- drop_all_tables(self._tables_metadata, self.bind)
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
# no need to run deletes if tables are recreated on setup
if (
file=sys.stderr,
)
- def setup(self):
- self._setup_each_tables()
- self._setup_each_inserts()
-
- def teardown(self):
- self._teardown_each_tables()
-
@classmethod
def _teardown_once_metadata_bind(cls):
if cls.run_create_tables:
- drop_all_tables(cls._tables_metadata, cls.bind)
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
if cls.run_setup_bind is not None:
cls.bind = None
- @classmethod
- def teardown_class(cls):
- cls._teardown_once_metadata_bind()
-
@classmethod
def setup_bind(cls):
return config.db
self._event_fns.add((target, name, fn))
event.listen(target, name, fn, **kw)
- def teardown(self):
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
for key in self._event_fns:
event.remove(*key)
- super_ = super(RemovesEvents, self)
- if hasattr(super_, "teardown"):
- super_.teardown()
-
-
-class _ORMTest(object):
- @classmethod
- def teardown_class(cls):
- sa.orm.session.close_all_sessions()
- sa.orm.clear_mappers()
-def create_session(**kw):
- kw.setdefault("autoflush", False)
- kw.setdefault("expire_on_commit", False)
- return sa.orm.Session(config.db, **kw)
+_fixture_sessions = set()
def fixture_session(**kw):
kw.setdefault("autoflush", True)
kw.setdefault("expire_on_commit", True)
- return sa.orm.Session(config.db, **kw)
+ sess = sa.orm.Session(config.db, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
+
+def _close_all_sessions():
+ # will close all still-referenced sessions
+ sa.orm.session.close_all_sessions()
+ _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+ _close_all_sessions()
+ sa.orm.clear_mappers()
-class ORMTest(_ORMTest, TestBase):
+def after_test():
+
+ if _fixture_sessions:
+
+ _close_all_sessions()
+
+
+class ORMTest(TestBase):
pass
-class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = "once"
classes = None
- @classmethod
- def setup_class(cls):
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
cls._init_class()
if cls.classes is None:
cls._setup_once_mappers()
cls._setup_once_inserts()
- @classmethod
- def teardown_class(cls):
+ yield
+
cls._teardown_once_class()
cls._teardown_once_metadata_bind()
- def setup(self):
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
self._setup_each_tables()
self._setup_each_classes()
self._setup_each_mappers()
self._setup_each_inserts()
- def teardown(self):
+ yield
+
sa.orm.session.close_all_sessions()
self._teardown_each_mappers()
self._teardown_each_classes()
@classmethod
def _teardown_once_class(cls):
cls.classes.clear()
- _ORMTest.teardown_class()
@classmethod
def _setup_once_classes(cls):
"""
cls_registry = cls.classes
+ assert cls_registry is not None
+
class FindFixture(type):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+ sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
+ if sys.version_info < (3, 0):
+ sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
+ "reinvent_fixtures_py2k"
+ )
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
import re
import sys
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
log = logging.getLogger("sqlalchemy.testing.plugin_base")
@post
def _set_disable_asyncio(opt, file_config):
- if opt.disable_asyncio:
+ if opt.disable_asyncio or not py3k:
from sqlalchemy.testing import asyncio
asyncio.ENABLE_ASYNCIO = False
config.requirements = testing.requires = req_cls()
+ config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
@post
def _prep_testing_database(options, file_config):
yield cls
-def start_test_class(cls):
+def start_test_class_outside_fixtures(cls):
_do_skips(cls)
_setup_engine(cls)
def stop_test_class(cls):
- # from sqlalchemy import inspect
- # assert not inspect(testing.db).get_table_names()
+ # close sessions, immediate connections, etc.
+ fixtures.stop_test_class_inside_fixtures(cls)
+
+ # close outstanding connection pool connections, dispose of
+ # additional engines
+ engines.testing_reaper.stop_test_class_inside_fixtures()
- provision.stop_test_class(config, config.db, cls)
- engines.testing_reaper._stop_test_ctx()
+
+def stop_test_class_outside_fixtures(cls):
+ provision.stop_test_class_outside_fixtures(config, config.db, cls)
try:
if not options.low_connections:
assertions.global_cleanup_assertions()
def final_process_cleanup():
- engines.testing_reaper._stop_test_ctx_aggressive()
+ engines.testing_reaper.final_cleanup()
assertions.global_cleanup_assertions()
_restore_engine()
def _setup_engine(cls):
if getattr(cls, "__engine_options__", None):
- eng = engines.testing_engine(options=cls.__engine_options__)
+ opts = dict(cls.__engine_options__)
+ opts["scope"] = "class"
+ eng = engines.testing_engine(options=opts)
config._current.push_engine(eng, testing)
def after_test(test):
- engines.testing_reaper._after_test_ctx()
+ fixtures.after_test()
+ engines.testing_reaper.after_test()
+
+
+def after_test_fixtures(test):
+ engines.testing_reaper.after_test_outside_fixtures(test)
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
def get_current_test_name(self):
raise NotImplementedError()
+ @abc.abstractmethod
+ def mark_base_test_class(self):
+ raise NotImplementedError()
+
_fixture_fn_class = None
import pytest
+
try:
import typing
except ImportError:
has_xdist = False
+py2k = sys.version_info < (3, 0)
+if py2k:
+ try:
+ import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+ except ImportError:
+ from . import reinvent_fixtures_py2k
+
+
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
else:
newitems.append(item)
+ if py2k:
+ for item in newitems:
+ reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)
+
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
items[:] = sorted(
def pytest_pycollect_makeitem(collector, name, obj):
-
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
obj = _apply_maybe_async(obj)
ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-
return [
ctor(name=parametrize_cls.__name__, parent=collector)
for parametrize_cls in _parametrize_cls(collector.module, obj)
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
- setup_names = {"setup", "setup_class", "teardown", "teardown_class"}
for name, value in vars(obj).items():
if (
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
- and (name.startswith("test_") or name in setup_names)
+ and (name.startswith("test_"))
and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
return obj
-_current_class = None
-
-
def _parametrize_cls(module, cls):
"""implement a class-based version of pytest parametrize."""
return classes
+_current_class = None
+
+
def pytest_runtest_setup(item):
from sqlalchemy.testing import asyncio
- # here we seem to get called only based on what we collected
- # in pytest_collection_modifyitems. So to do class-based stuff
- # we have to tear that out.
- global _current_class
-
if not isinstance(item, pytest.Function):
return
- # ... so we're doing a little dance here to figure it out...
+    # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
+    # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
+    # for the whole class and has to run things that span all current
+    # databases, so we run this outside of the pytest fixture system
+    # altogether, ensuring an asyncio greenlet is used if any engines
+    # are async
+
+ global _current_class
+
if _current_class is None:
- asyncio._maybe_async(class_setup, item.parent.parent)
+ asyncio._maybe_async_provisioning(
+ plugin_base.start_test_class_outside_fixtures,
+ item.parent.parent.cls,
+ )
_current_class = item.parent.parent
- # this is needed for the class-level, to ensure that the
- # teardown runs after the class is completed with its own
- # class-level teardown...
def finalize():
global _current_class
- asyncio._maybe_async(class_teardown, item.parent.parent)
_current_class = None
+ asyncio._maybe_async_provisioning(
+ plugin_base.stop_test_class_outside_fixtures,
+ item.parent.parent.cls,
+ )
+
item.parent.parent.addfinalizer(finalize)
- asyncio._maybe_async(test_setup, item)
+def pytest_runtest_call(item):
+ # runs inside of pytest function fixture scope
+ # before test function runs
-def pytest_runtest_teardown(item):
from sqlalchemy.testing import asyncio
- # ...but this works better as the hook here rather than
- # using a finalizer, as the finalizer seems to get in the way
- # of the test reporting failures correctly (you get a bunch of
- # pytest assertion stuff instead)
- asyncio._maybe_async(test_teardown, item)
+ asyncio._maybe_async(
+ plugin_base.before_test,
+ item,
+ item.parent.module.__name__,
+ item.parent.cls,
+ item.name,
+ )
-def test_setup(item):
- plugin_base.before_test(
- item, item.parent.module.__name__, item.parent.cls, item.name
- )
+def pytest_runtest_teardown(item, nextitem):
+ # runs inside of pytest function fixture scope
+ # after test function runs
+ from sqlalchemy.testing import asyncio
-def test_teardown(item):
- plugin_base.after_test(item)
+ asyncio._maybe_async(plugin_base.after_test, item)
-def class_setup(item):
+@pytest.fixture(scope="class")
+def setup_class_methods(request):
from sqlalchemy.testing import asyncio
- asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls)
+ cls = request.cls
+
+ if hasattr(cls, "setup_test_class"):
+ asyncio._maybe_async(cls.setup_test_class)
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_setup(request)
+
+ yield
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_teardown(request)
+ if hasattr(cls, "teardown_test_class"):
+ asyncio._maybe_async(cls.teardown_test_class)
-def class_teardown(item):
- plugin_base.stop_test_class(item.cls)
+ asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
+@pytest.fixture(scope="function")
+def setup_test_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ # called for each test
+
+ self = request.instance
+
+ # 1. run outer xdist-style setup
+ if hasattr(self, "setup_test"):
+ asyncio._maybe_async(self.setup_test)
+
+ # alembic test suite is using setUp and tearDown
+ # xdist methods; support these in the test suite
+ # for the near term
+ if hasattr(self, "setUp"):
+ asyncio._maybe_async(self.setUp)
+
+ # 2. run homegrown function level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_setup(request)
+
+ # inside the yield:
+
+ # 3. function level "autouse" fixtures under py3k (examples: TablesTest
+ # define tables / data, MappedTest define tables / mappers / data)
+
+ # 4. function level fixtures defined on test functions themselves,
+ # e.g. "connection", "metadata" run next
+
+ # 5. pytest hook pytest_runtest_call then runs
+
+ # 6. test itself runs
+
+ yield
+
+ # yield finishes:
+
+    # 7. the pytest hook pytest_runtest_teardown runs; this is associated
+    # with fixtures that close all sessions, provisioning.stop_test_class(),
+    # and engines.testing_reaper -> ensures all connection pool connections
+    # are returned and engines created by testing_engine that aren't the
+    # config engine are disposed
+
+ # 8. function level fixtures defined on test functions
+ # themselves, e.g. "connection" rolls back the transaction, "metadata"
+ # emits drop all
+
+ # 9. function level "autouse" fixtures under py3k (examples: TablesTest /
+ # MappedTest delete table data, possibly drop tables and clear mappers
+ # depending on the flags defined by the test class)
+
+ # 10. run homegrown function-level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_teardown(request)
+
+ asyncio._maybe_async(plugin_base.after_test_fixtures, self)
+
+ # 11. run outer xdist-style teardown
+ if hasattr(self, "tearDown"):
+ asyncio._maybe_async(self.tearDown)
+
+ if hasattr(self, "teardown_test"):
+ asyncio._maybe_async(self.teardown_test)
def getargspec(fn):
# for the wrapped function
decorated.__module__ = fn.__module__
decorated.__name__ = fn.__name__
+ if hasattr(fn, "pytestmark"):
+ decorated.pytestmark = fn.pytestmark
return decorated
return decorate
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
+ def mark_base_test_class(self):
+ return pytest.mark.usefixtures(
+ "setup_class_methods", "setup_test_methods"
+ )
+
_combination_id_fns = {
"i": lambda obj: obj,
"r": repr,
fn = asyncio._maybe_async_wrapper(fn)
# other wrappers may be added here
- # now apply FixtureFunctionMarker
- fn = fixture(fn)
+ if py2k and "autouse" in kw:
+ # py2k workaround for too-slow collection of autouse fixtures
+ # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
+ # rationale.
+
+ # comment this condition out in order to disable the
+ # py2k workaround entirely.
+ reinvent_fixtures_py2k.add_fixture(fn, fixture)
+ else:
+ # now apply FixtureFunctionMarker
+ fn = fixture(fn)
+
return fn
if fn:
--- /dev/null
+"""
+invent a quick version of pytest autouse fixtures, to work around pytest's
+unacceptably slow collection/high memory use in pytest 4.6.11, which is the
+highest version that works in py2k.
+
+by "too slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so, and memory use seems to be very
+high as well. for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures is limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
+_py2k_fixture_fn_names = collections.defaultdict(set)
+_py2k_class_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+_py2k_function_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+
+_py2k_cls_fixture_stack = []
+_py2k_fn_fixture_stack = []
+
+
+def add_fixture(fn, fixture):
+ assert fixture.scope in ("class", "function")
+ _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
+def scan_for_fixtures_to_use_for_class(item):
+ test_class = item.parent.parent.obj
+
+ for name in _py2k_fixture_fn_names:
+ for fixture_fn, scope in _py2k_fixture_fn_names[name]:
+ meth = getattr(test_class, name, None)
+ if meth and meth.im_func is fixture_fn:
+ for sup in test_class.__mro__:
+ if name in sup.__dict__:
+ if scope == "class":
+ _py2k_class_fixtures[test_class][sup].add(meth)
+ elif scope == "function":
+ _py2k_function_fixtures[test_class][sup].add(meth)
+ break
+ break
+
+
+def run_class_fixture_setup(request):
+
+ cls = request.cls
+ self = cls.__new__(cls)
+
+ fixtures_for_this_class = _py2k_class_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in cls.__mro__:
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_cls_fixture_stack.append(iter_)
+
+
+def run_class_fixture_teardown(request):
+ while _py2k_cls_fixture_stack:
+ iter_ = _py2k_cls_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
+
+
+def run_fn_fixture_setup(request):
+ cls = request.cls
+ self = request.instance
+
+ fixtures_for_this_class = _py2k_function_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in reversed(cls.__mro__):
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_fn_fixture_stack.append(iter_)
+
+
+def run_fn_fixture_teardown(request):
+ while _py2k_fn_fixture_stack:
+ iter_ = _py2k_fn_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
db_url = follower_url_from_main(db_url, follower_ident)
db_opts = {}
update_db_opts(db_url, db_opts)
+ db_opts["scope"] = "global"
eng = engines.testing_engine(db_url, db_opts)
post_configure_engine(db_url, eng, follower_ident)
eng.connect().close()
if config.requirements.schemas.enabled_for_config(cfg):
util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
drop_all_schema_objects_post_tables(cfg, eng)
def post_configure_engine(url, engine, follower_ident):
"""Perform extra steps after configuring an engine for testing.
- (For the internal dialects, currently only used by sqlite.)
+ (For the internal dialects, currently only used by sqlite, oracle)
"""
pass
@register.init
-def stop_test_class(config, db, testcls):
+def prepare_for_drop_tables(config, connection):
+ pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
pass
from sqlalchemy import pool
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool)
+ options=dict(poolclass=pool.StaticPool, scope="class"),
)
else:
return config.db
)
return self.engine
- def tearDown(self):
- engines.testing_reaper.close_all()
- self.engine.dispose()
-
@testing.combinations(
("global_string", True, "select 1", True),
("global_text", True, text("select 1"), True),
def test_conn_option(self):
engine = self._fixture(False)
- # should be enabled for this one
- result = (
- engine.connect()
- .execution_options(stream_results=True)
- .exec_driver_sql("select 1")
- )
- assert self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
def test_stmt_enabled_conn_option_disabled(self):
engine = self._fixture(False)
s = select(1).execution_options(stream_results=True)
- # not this one
- result = (
- engine.connect().execution_options(stream_results=False).execute(s)
- )
- assert not self._is_server_side(result.cursor)
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
def test_aliases_and_ss(self):
engine = self._fixture(False)
assert not self._is_server_side(result.cursor)
result.close()
- @testing.provide_metadata
- def test_roundtrip_fetchall(self):
+ def test_roundtrip_fetchall(self, metadata):
md = self.metadata
engine = self._fixture(True)
0,
)
- @testing.provide_metadata
- def test_roundtrip_fetchmany(self):
+ def test_roundtrip_fetchmany(self, metadata):
md = self.metadata
engine = self._fixture(True)
__backend__ = True
@testing.fixture
- def do_numeric_test(self, metadata):
+ def do_numeric_test(self, metadata, connection):
@testing.emits_warning(
r".*does \*not\* support Decimal objects natively"
)
def run(type_, input_, output, filter_=None, check_scale=False):
t = Table("t", metadata, Column("x", type_))
- t.create(testing.db)
- with config.db.begin() as conn:
- conn.execute(t.insert(), [{"x": x} for x in input_])
-
- result = {row[0] for row in conn.execute(t.select())}
- output = set(output)
- if filter_:
- result = set(filter_(x) for x in result)
- output = set(filter_(x) for x in output)
- eq_(result, output)
- if check_scale:
- eq_([str(x) for x in result], [str(x) for x in output])
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
return run
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String()),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
class JSONLegacyStringCastIndexTest(
from . import config
from . import mock
from .. import inspect
+from ..engine import Connection
from ..schema import Column
from ..schema import DropConstraint
from ..schema import DropTable
@decorator
def provide_metadata(fn, *args, **kw):
- """Provide bound MetaData for a single test, dropping afterwards."""
+ """Provide bound MetaData for a single test, dropping afterwards.
- # import cycle that only occurs with py2k's import resolver
- # in py3k this can be moved top level.
- from . import engines
+ Legacy; use the "metadata" pytest fixture.
+
+ """
+
+ from . import fixtures
metadata = schema.MetaData()
self = args[0]
try:
return fn(*args, **kw)
finally:
- engines.drop_all_tables(metadata, config.db)
+ # close out some things that get in the way of dropping tables.
+ # when using the "metadata" fixture, there is a set ordering
+ # of things that makes sure things are cleaned up in order, however
+ # the simple "decorator" nature of this legacy function means
+ # we have to hardcode some of that cleanup ahead of time.
+
+ # close ORM sessions
+ fixtures._close_all_sessions()
+
+ # integrate with the "connection" fixture as there are many
+ # tests where it is used along with provide_metadata
+ if fixtures._connection_fixture_connection:
+ # TODO: this warning can be used to find all the places
+ # this is used with connection fixture
+ # warn("mixing legacy provide metadata with connection fixture")
+ drop_all_tables_from_metadata(
+ metadata, fixtures._connection_fixture_connection
+ )
+ # as the provide_metadata fixture is often used with "testing.db",
+ # when we do the drop we have to commit the transaction so that
+ # the DB is actually updated as the CREATE would have been
+ # committed
+ fixtures._connection_fixture_connection.get_transaction().commit()
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
self.metadata = prev_meta
get_all = __call__
+def drop_all_tables_from_metadata(metadata, engine_or_connection):
+ from . import engines
+
+ def go(connection):
+ engines.testing_reaper.prepare_for_drop_tables(connection)
+
+ if not connection.dialect.supports_alter:
+ from . import assertions
+
+ with assertions.expect_warnings(
+ "Can't sort tables", assert_=False
+ ):
+ metadata.drop_all(connection)
+ else:
+ metadata.drop_all(connection)
+
+ if not isinstance(engine_or_connection, Connection):
+ with engine_or_connection.begin() as connection:
+ go(connection)
+ else:
+ go(engine_or_connection)
+
+
def drop_all_tables(engine, inspector, schema=None, include_names=None):
if include_names is not None:
return self.put_nowait(item)
try:
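+            # test "timeout is not None" so that a timeout of zero is
+            # honored as "time out immediately" rather than as a falsey
+            # "no timeout"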
- if timeout:
+ if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.put(item), timeout)
)
else:
return self.await_(self._queue.put(item))
- except asyncio.queues.QueueFull as err:
+        except (
+            asyncio.queues.QueueFull,
+            asyncio.TimeoutError,
+        ) as err:
compat.raise_(
Full(),
replace_context=err,
def get(self, block=True, timeout=None):
if not block:
return self.get_nowait()
+
try:
- if timeout:
+ if timeout is not None:
return self.await_(
asyncio.wait_for(self._queue.get(), timeout)
)
else:
return self.await_(self._queue.get())
- except asyncio.queues.QueueEmpty as err:
+        except (
+            asyncio.queues.QueueEmpty,
+            asyncio.TimeoutError,
+        ) as err:
compat.raise_(
Empty(),
replace_context=err,
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, metadata
metadata = MetaData()
class EnsureZeroed(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
_sessions.clear()
_mapper_registry.clear()
t2_mapper = mapper(T2, t2)
t1_mapper.add_property("bar", relationship(t2_mapper))
- s1 = fixture_session()
+ s1 = Session(testing.db)
# this causes the path_registry to be invoked
s1.query(t1_mapper)._compile_context()
class EnumTest(fixtures.TestBase):
__requires__ = ("cpython", "python_profiling_backend")
- def setup(self):
+ def setup_test(self):
class SomeEnum(object):
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
run_setup_bind = "each"
@classmethod
- def setup_class(cls):
- super(NoCache, cls).setup_class()
+ def setup_test_class(cls):
cls._cache = config.db._compiled_cache
config.db._compiled_cache = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
config.db._compiled_cache = cls._cache
- super(NoCache, cls).teardown_class()
class MergeTest(NoCache, fixtures.MappedTest):
def close(self):
pass
- def setup(self):
+ def setup_test(self):
# create a throwaway pool which
# has the effect of initializing
# class-level event listeners on Pool,
class TearDownLocalEventsFixture(object):
- def tearDown(self):
+ def teardown_test(self):
classes = set()
for entry in event.base._registrars.values():
for evt_cls in entry:
class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test class- and instance-level event registration."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, x, y):
pass
class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test adaption of legacy args"""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
@event._legacy_signature("0.9", ["x", "y"])
def event_three(self, x, y, z, q):
class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test default target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
pass
class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom target acceptance."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _accept_with(cls, target):
class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test that ad-hoc subclasses are garbage collected."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def some_event(self, x, y):
pass
"""Test custom listen functions which change the listener function
signature."""
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
@classmethod
def _listen(cls, event_key, add=False):
class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, arg):
pass
class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
pass
class TestInspection(fixtures.TestBase):
- def tearDown(self):
+ def teardown_test(self):
for type_ in list(inspection._registrars):
if issubclass(type_, TestFixture):
del inspection._registrars[type_]
ddl.sort_tables_and_constraints = self.orig_sort
- def setup(self):
+ def setup_test(self):
self._setup_logger()
self._setup_create_table_patcher()
- def teardown(self):
+ def teardown_test(self):
self._teardown_create_table_patcher()
self._teardown_logger()
class SchemaTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
t = Table(
"sometable",
MetaData(),
"""
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.t1 = table(
"t1",
return testing.db.execution_options(isolation_level="AUTOCOMMIT")
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.connect().execution_options(
isolation_level="AUTOCOMMIT"
) as conn:
conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
except:
pass
- super(MatchTest, cls).setup_class()
@classmethod
def insert_data(cls, connection):
class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = mysql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"foos",
MetaData(),
class RegexpCommon(testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
class RawReflectionTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
dialect = mysql.dialect()
self.parser = _reflection.MySQLTableDefinitionParser(
dialect, dialect.identifier_preparer
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "oracle"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as c:
c.exec_driver_sql(
"""
assert isinstance(result.out_parameters["x_out"], int)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.execute(text("DROP PROCEDURE foo"))
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# currently assuming full DBA privs for the user.
# don't really know how else to go here unless
# we connect as the other user.
conn.exec_driver_sql(stmt)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for stmt in (
"""
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("create table my_table (id integer)")
conn.exec_driver_sql(
"create table foo_table (id integer) tablespace SYSTEM"
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table my_temp_table")
conn.exec_driver_sql("drop table my_table")
__only_on__ = "oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
""",
)
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table admin_docindex")
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.testing import config
cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop synonym test_table_syn")
conn.exec_driver_sql("drop table test_table")
__only_on__ = "oracle+cx_oracle"
__backend__ = True
- def setup(self):
+ def setup_test(self):
connect = testing.db.pool._creator
def _creator():
self.engine = testing_engine(options={"creator": _creator})
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_were_getting_a_comma(self):
# TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
# merges
- from sqlalchemy.testing import engines
+ from sqlalchemy.testing import util as testing_util
from sqlalchemy.sql import schema
metadata = schema.MetaData()
try:
yield metadata
finally:
- engines.drop_all_tables(metadata, testing.db)
+ testing_util.drop_all_tables_from_metadata(metadata, testing.db)
@async_test
async def test_detect_stale_ddl_cache_raise_recover(
class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table1 = table1 = table(
"mytable",
column("myid", Integer),
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
__dialect__ = postgresql.dialect()
- def setup(self):
+ def setup_test(self):
self.table = Table(
"t",
MetaData(),
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "postgresql"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
@config.fixture()
def connection(self):
- eng = engines.testing_engine(options=self.options)
+ opts = dict(self.options)
+ opts["use_reaper"] = False
+ eng = engines.testing_engine(options=opts)
conn = eng.connect()
trans = conn.begin()
- try:
- yield conn
- finally:
- if trans.is_active:
- trans.rollback()
- conn.close()
- eng.dispose()
+ yield conn
+ if trans.is_active:
+ trans.rollback()
+ conn.close()
+ eng.dispose()
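The ``try/finally`` wrapper is dropped as redundant here: pytest resumes the
generator after the test regardless of outcome, so the cleanup following
``yield`` still runs, and a failure during cleanup itself now surfaces
directly. Condensed, the shape is::

    @config.fixture()
    def connection(self):
        eng = engines.testing_engine(options={"use_reaper": False})
        conn = eng.connect()
        trans = conn.begin()
        yield conn
        # teardown: pytest re-enters the generator after the test completes
        if trans.is_active:
            trans.rollback()
        conn.close()
        eng.dispose()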
@classmethod
def define_tables(cls, metadata):
# assert result.closed
assert result.cursor is None
- @testing.provide_metadata
- def test_insert_returning_preexecute_pk(self, connection):
+ def test_insert_returning_preexecute_pk(self, metadata, connection):
counter = itertools.count(1)
t = Table(
),
Column("data", Integer),
)
- self.metadata.create_all(connection)
+ metadata.create_all(connection)
result = connection.execute(
t.insert().return_defaults(),
__only_on__ = "postgresql"
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
self.metadata.drop_all(conn)
def setup_bind(cls):
from sqlalchemy import event
- eng = engines.testing_engine()
+ eng = engines.testing_engine(options={"scope": "class"})
@event.listens_for(eng, "connect")
def connect(dbapi_conn, rec):
]:
sa.event.listen(metadata, "before_drop", sa.DDL(ddl))
- def test_foreign_table_is_reflected(self):
+ def test_foreign_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_foreigntable", metadata, autoload_with=testing.db)
+ table = Table("test_foreigntable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected foreign table didn't equal expected columns",
)
- def test_get_foreign_table_names(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- ft_names = inspector.get_foreign_table_names()
- eq_(ft_names, ["test_foreigntable"])
+ def test_get_foreign_table_names(self, connection):
+ inspector = inspect(connection)
+ ft_names = inspector.get_foreign_table_names()
+ eq_(ft_names, ["test_foreigntable"])
- def test_get_table_names_no_foreign(self):
- inspector = inspect(testing.db)
- with testing.db.connect():
- names = inspector.get_table_names()
- eq_(names, ["testtable"])
+ def test_get_table_names_no_foreign(self, connection):
+ inspector = inspect(connection)
+ names = inspector.get_table_names()
+ eq_(names, ["testtable"])
class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
if testing.against("postgresql >= 11"):
Index("my_index", dv.c.q)
- def test_get_tablenames(self):
+ def test_get_tablenames(self, connection):
assert {"data_values", "data_values_4_10"}.issubset(
- inspect(testing.db).get_table_names()
+ inspect(connection).get_table_names()
)
- def test_reflect_cols(self):
- cols = inspect(testing.db).get_columns("data_values")
+ def test_reflect_cols(self, connection):
+ cols = inspect(connection).get_columns("data_values")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
- def test_reflect_cols_from_partition(self):
- cols = inspect(testing.db).get_columns("data_values_4_10")
+ def test_reflect_cols_from_partition(self, connection):
+ cols = inspect(connection).get_columns("data_values_4_10")
eq_([c["name"] for c in cols], ["modulus", "data", "q"])
@testing.only_on("postgresql >= 11")
- def test_reflect_index(self):
- idx = inspect(testing.db).get_indexes("data_values")
+ def test_reflect_index(self, connection):
+ idx = inspect(connection).get_indexes("data_values")
eq_(
idx,
[
)
@testing.only_on("postgresql >= 11")
- def test_reflect_index_from_partition(self):
- idx = inspect(testing.db).get_indexes("data_values_4_10")
+ def test_reflect_index_from_partition(self, connection):
+ idx = inspect(connection).get_indexes("data_values_4_10")
# note the name appears to be generated by PG, currently
# 'data_values_4_10_q_idx'
eq_(
testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
)
- def test_mview_is_reflected(self):
+ def test_mview_is_reflected(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
+ table = Table("test_mview", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["id", "data"]),
"Columns of reflected mview didn't equal expected columns",
)
- def test_mview_select(self):
+ def test_mview_select(self, connection):
metadata = MetaData()
- table = Table("test_mview", metadata, autoload_with=testing.db)
- with testing.db.connect() as conn:
- eq_(conn.execute(table.select()).fetchall(), [(89, "d1")])
+ table = Table("test_mview", metadata, autoload_with=connection)
+ eq_(connection.execute(table.select()).fetchall(), [(89, "d1")])
- def test_get_view_names(self):
- insp = inspect(testing.db)
+ def test_get_view_names(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_plain(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
- def test_get_view_names_plain_string(self):
- insp = inspect(testing.db)
+ def test_get_view_names_plain_string(self, connection):
+ insp = inspect(connection)
eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
- def test_get_view_names_materialized(self):
- insp = inspect(testing.db)
+ def test_get_view_names_materialized(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("materialized",))),
set(["test_mview"]),
)
- def test_get_view_names_reflection_cache_ok(self):
- insp = inspect(testing.db)
+ def test_get_view_names_reflection_cache_ok(self, connection):
+ insp = inspect(connection)
eq_(
set(insp.get_view_names(include=("plain",))), set(["test_regview"])
)
)
eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_empty(self):
- insp = inspect(testing.db)
+ def test_get_view_names_empty(self, connection):
+ insp = inspect(connection)
assert_raises(ValueError, insp.get_view_names, include=())
- def test_get_view_definition(self):
- insp = inspect(testing.db)
+ def test_get_view_definition(self, connection):
+ insp = inspect(connection)
eq_(
re.sub(
r"[\n\t ]+",
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as con:
for ddl in [
'CREATE SCHEMA "SomeSchema"',
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as con:
con.exec_driver_sql("DROP TABLE testtable")
con.exec_driver_sql("DROP TABLE test_schema.testtable")
con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
- def test_table_is_reflected(self):
+ def test_table_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
set(table.columns.keys()),
set(["question", "answer"]),
)
assert isinstance(table.c.answer.type, Integer)
- def test_domain_is_reflected(self):
+ def test_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("testtable", metadata, autoload_with=testing.db)
+ table = Table("testtable", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"42",
not table.columns.answer.nullable
), "Expected reflected column to not be nullable."
- def test_enum_domain_is_reflected(self):
+ def test_enum_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("enum_test", metadata, autoload_with=testing.db)
+ table = Table("enum_test", metadata, autoload_with=connection)
eq_(table.c.data.type.enums, ["test"])
- def test_array_domain_is_reflected(self):
+ def test_array_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("array_test", metadata, autoload_with=testing.db)
+ table = Table("array_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, ARRAY)
eq_(table.c.data.type.item_type.__class__, INTEGER)
- def test_quoted_remote_schema_domain_is_reflected(self):
+ def test_quoted_remote_schema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("quote_test", metadata, autoload_with=testing.db)
+ table = Table("quote_test", metadata, autoload_with=connection)
eq_(table.c.data.type.__class__, INTEGER)
- def test_table_is_reflected_test_schema(self):
+ def test_table_is_reflected_test_schema(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
)
assert isinstance(table.c.anything.type, Integer)
- def test_schema_domain_is_reflected(self):
+ def test_schema_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table(
"testtable",
metadata,
- autoload_with=testing.db,
+ autoload_with=connection,
schema="test_schema",
)
eq_(
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_crosschema_domain_is_reflected(self):
+ def test_crosschema_domain_is_reflected(self, connection):
metadata = MetaData()
- table = Table("crosschema", metadata, autoload_with=testing.db)
+ table = Table("crosschema", metadata, autoload_with=connection)
eq_(
str(table.columns.answer.server_default.arg),
"0",
table.columns.answer.nullable
), "Expected reflected column to be nullable."
- def test_unknown_types(self):
+ def test_unknown_types(self, connection):
from sqlalchemy.dialects.postgresql import base
ischema_names = base.PGDialect.ischema_names
try:
m2 = MetaData()
assert_raises(
- exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db
+ exc.SAWarning, Table, "testtable", m2, autoload_with=connection
)
@testing.emits_warning("Did not recognize type")
def warns():
m3 = MetaData()
- t3 = Table("testtable", m3, autoload_with=testing.db)
+ t3 = Table("testtable", m3, autoload_with=connection)
assert t3.c.answer.type.__class__ == sa.types.NullType
finally:
subject = Table("subject", meta2, autoload_with=connection)
eq_(subject.primary_key.columns.keys(), ["p2", "p1"])
- @testing.provide_metadata
- def test_pg_weirdchar_reflection(self):
- meta1 = self.metadata
+ def test_pg_weirdchar_reflection(self, metadata, connection):
+ meta1 = metadata
subject = Table(
"subject", meta1, Column("id$", Integer, primary_key=True)
)
Column("id", Integer, primary_key=True),
Column("ref", Integer, ForeignKey("subject.id$")),
)
- meta1.create_all(testing.db)
+ meta1.create_all(connection)
meta2 = MetaData()
- subject = Table("subject", meta2, autoload_with=testing.db)
- referer = Table("referer", meta2, autoload_with=testing.db)
+ subject = Table("subject", meta2, autoload_with=connection)
+ referer = Table("referer", meta2, autoload_with=connection)
self.assert_(
(subject.c["id$"] == referer.c.ref).compare(
subject.join(referer).onclause
)
)
- @testing.provide_metadata
- def test_reflect_default_over_128_chars(self):
+ def test_reflect_default_over_128_chars(self, metadata, connection):
Table(
"t",
- self.metadata,
+ metadata,
Column("x", String(200), server_default="abcd" * 40),
- ).create(testing.db)
+ ).create(connection)
m = MetaData()
- t = Table("t", m, autoload_with=testing.db)
+ t = Table("t", m, autoload_with=connection)
eq_(
t.c.x.server_default.arg.text,
"'%s'::character varying" % ("abcd" * 40),
)
- @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure")
- @testing.provide_metadata
- def test_renamed_sequence_reflection(self):
- metadata = self.metadata
+ def test_renamed_sequence_reflection(self, metadata, connection):
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False)
+ t2 = Table("t", m2, autoload_with=connection, implicit_returning=False)
eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)")
- with testing.db.begin() as conn:
- r = conn.execute(t2.insert())
- eq_(r.inserted_primary_key, (1,))
+ r = connection.execute(t2.insert())
+ eq_(r.inserted_primary_key, (1,))
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t_id_seq rename to foobar_id_seq"
- )
+ connection.exec_driver_sql(
+ "alter table t_id_seq rename to foobar_id_seq"
+ )
m3 = MetaData()
- t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
+ t3 = Table("t", m3, autoload_with=connection, implicit_returning=False)
eq_(
t3.c.id.server_default.arg.text,
"nextval('foobar_id_seq'::regclass)",
)
- with testing.db.begin() as conn:
- r = conn.execute(t3.insert())
- eq_(r.inserted_primary_key, (2,))
+ r = connection.execute(t3.insert())
+ eq_(r.inserted_primary_key, (2,))
- @testing.provide_metadata
- def test_altered_type_autoincrement_pk_reflection(self):
- metadata = self.metadata
+ def test_altered_type_autoincrement_pk_reflection(
+ self, metadata, connection
+ ):
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "alter table t alter column id type varchar(50)"
- )
+ connection.exec_driver_sql(
+ "alter table t alter column id type varchar(50)"
+ )
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_(t2.c.id.autoincrement, False)
eq_(t2.c.x.autoincrement, False)
- @testing.provide_metadata
- def test_renamed_pk_reflection(self):
- metadata = self.metadata
+ def test_renamed_pk_reflection(self, metadata, connection):
Table("t", metadata, Column("id", Integer, primary_key=True))
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("alter table t rename id to t_id")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("alter table t rename id to t_id")
m2 = MetaData()
- t2 = Table("t", m2, autoload_with=testing.db)
+ t2 = Table("t", m2, autoload_with=connection)
eq_([c.name for c in t2.primary_key], ["t_id"])
- @testing.provide_metadata
- def test_has_temporary_table(self):
- assert not inspect(testing.db).has_table("some_temp_table")
+ def test_has_temporary_table(self, metadata, connection):
+ assert not inspect(connection).has_table("some_temp_table")
user_tmp = Table(
"some_temp_table",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("name", String(50)),
prefixes=["TEMPORARY"],
)
- user_tmp.create(testing.db)
- assert inspect(testing.db).has_table("some_temp_table")
+ user_tmp.create(connection)
+ assert inspect(connection).has_table("some_temp_table")
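Here and below, the ``metadata`` fixture replaces ``@testing.provide_metadata``
and the ``self.metadata`` attribute it injected; the fixture supplies a fresh
``MetaData`` and drops whatever was created on it after the test. A sketch of
the migrated shape (table name illustrative)::

    def test_reflects_table(self, metadata, connection):
        t = Table("some_table", metadata, Column("id", Integer, primary_key=True))
        metadata.create_all(connection)  # the fixture handles the drop
        assert inspect(connection).has_table("some_table")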
def test_cross_schema_reflection_one(self, metadata, connection):
A_table.create(connection, checkfirst=True)
assert inspect(connection).has_table("A")
- def test_uppercase_lowercase_sequence(self):
+ def test_uppercase_lowercase_sequence(self, connection):
a_seq = Sequence("a")
A_seq = Sequence("A")
- a_seq.create(testing.db)
- assert testing.db.dialect.has_sequence(testing.db, "a")
- assert not testing.db.dialect.has_sequence(testing.db, "A")
- A_seq.create(testing.db, checkfirst=True)
- assert testing.db.dialect.has_sequence(testing.db, "A")
+ a_seq.create(connection)
+ assert connection.dialect.has_sequence(connection, "a")
+ assert not connection.dialect.has_sequence(connection, "A")
+ A_seq.create(connection, checkfirst=True)
+ assert connection.dialect.has_sequence(connection, "A")
- a_seq.drop(testing.db)
- A_seq.drop(testing.db)
+ a_seq.drop(connection)
+ A_seq.drop(connection)
def test_index_reflection(self, metadata, connection):
"""Reflecting expression-based indexes should warn"""
],
)
- @testing.provide_metadata
- def test_index_reflection_partial(self, connection):
+ def test_index_reflection_partial(self, metadata, connection):
"""Reflect the filter defintion on partial indexes"""
- metadata = self.metadata
+ metadata = metadata
t1 = Table(
"table1",
metadata.create_all(connection)
- ind = testing.db.dialect.get_indexes(connection, t1, None)
+ ind = connection.dialect.get_indexes(connection, t1, None)
partial_definitions = []
for ix in ind:
compile_exprs(r3.expressions),
)
- @testing.provide_metadata
- def test_index_reflection_modified(self):
+ def test_index_reflection_modified(self, metadata, connection):
"""reflect indexes when a column name has changed - PG 9
does not update the name of the column in the index def.
[ticket:2141]
"""
- metadata = self.metadata
Table(
"t",
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
- conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+ connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {"name": "idx1", "unique": False, "column_names": ["y"]}
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
+ ind = connection.dialect.get_indexes(connection, "t", None)
+ expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
- eq_(ind, expected)
+ eq_(ind, expected)
- @testing.fails_if("postgresql < 8.2", "reloptions not supported")
- @testing.provide_metadata
- def test_index_reflection_with_storage_options(self):
+ def test_index_reflection_with_storage_options(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
Table(
"t",
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
- metadata.create_all(testing.db)
+ metadata.create_all(connection)
- with testing.db.begin() as conn:
- conn.exec_driver_sql(
- "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
- )
+ connection.exec_driver_sql(
+ "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
+ )
- ind = testing.db.dialect.get_indexes(conn, "t", None)
+ ind = connection.dialect.get_indexes(connection, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {
- "postgresql_with": {"fillfactor": "50"}
- },
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_with": {"fillfactor": "50"}},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["with"],
- {"fillfactor": "50"},
- )
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["with"],
+ {"fillfactor": "50"},
+ )
- @testing.provide_metadata
- def test_index_reflection_with_access_method(self):
+ def test_index_reflection_with_access_method(self, metadata, connection):
"""reflect indexes with storage options set"""
- metadata = self.metadata
-
Table(
"t",
metadata,
Column("id", Integer, primary_key=True),
Column("x", ARRAY(Integer)),
)
- metadata.create_all(testing.db)
- with testing.db.begin() as conn:
- conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
+ metadata.create_all(connection)
+ connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
- ind = testing.db.dialect.get_indexes(conn, "t", None)
- expected = [
- {
- "unique": False,
- "column_names": ["x"],
- "name": "idx1",
- "dialect_options": {"postgresql_using": "gin"},
- }
- ]
- if testing.requires.index_reflects_included_columns.enabled:
- expected[0]["include_columns"] = []
- eq_(ind, expected)
- m = MetaData()
- t1 = Table("t", m, autoload_with=conn)
- eq_(
- list(t1.indexes)[0].dialect_options["postgresql"]["using"],
- "gin",
- )
+ ind = connection.dialect.get_indexes(connection, "t", None)
+ expected = [
+ {
+ "unique": False,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {"postgresql_using": "gin"},
+ }
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(ind, expected)
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(
+ list(t1.indexes)[0].dialect_options["postgresql"]["using"],
+ "gin",
+ )
@testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
def test_index_reflection_with_include(self, metadata, connection):
# [{'column_names': ['x', 'name'],
# 'name': 'idx1', 'unique': False}]
- ind = testing.db.dialect.get_indexes(connection, "t", None)
+ ind = connection.dialect.get_indexes(connection, "t", None)
eq_(
ind,
[
for fk in fks:
eq_(fk, fk_ref[fk["name"]])
- @testing.provide_metadata
- def test_inspect_enums_schema(self, connection):
+ def test_inspect_enums_schema(self, metadata, connection):
enum_type = postgresql.ENUM(
"sad",
"ok",
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
enum_type.create(connection)
inspector = inspect(connection)
],
)
- @testing.provide_metadata
- def test_inspect_enums(self):
+ def test_inspect_enums(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
[
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive(self):
+ def test_inspect_enums_case_sensitive(self, metadata, connection):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
for enum in "lower_case", "UpperCase", "Name.With.Dot":
"CapsTwo",
name=enum,
schema=schema,
- metadata=self.metadata,
+ metadata=metadata,
)
- self.metadata.create_all(testing.db)
- inspector = inspect(testing.db)
+ metadata.create_all(connection)
+ inspector = inspect(connection)
for schema in None, "test_schema", "TestSchema":
eq_(
sorted(
],
)
- @testing.provide_metadata
- def test_inspect_enums_case_sensitive_from_table(self):
+ def test_inspect_enums_case_sensitive_from_table(
+ self, metadata, connection
+ ):
sa.event.listen(
- self.metadata,
+ metadata,
"before_create",
sa.DDL('create schema "TestSchema"'),
)
sa.event.listen(
- self.metadata,
+ metadata,
"after_drop",
- sa.DDL('drop schema "TestSchema" cascade'),
+ sa.DDL('drop schema if exists "TestSchema" cascade'),
)
counter = itertools.count()
"CapsOne",
"CapsTwo",
name=enum,
- metadata=self.metadata,
+ metadata=metadata,
schema=schema,
)
Table(
"t%d" % next(counter),
- self.metadata,
+ metadata,
Column("q", enum_type),
)
- self.metadata.create_all(testing.db)
+ metadata.create_all(connection)
- inspector = inspect(testing.db)
+ inspector = inspect(connection)
counter = itertools.count()
for enum in "lower_case", "UpperCase", "Name.With.Dot":
for schema in None, "test_schema", "TestSchema":
],
)
- @testing.provide_metadata
- def test_inspect_enums_star(self):
+ def test_inspect_enums_star(self, metadata, connection):
enum_type = postgresql.ENUM(
- "cat", "dog", "rat", name="pet", metadata=self.metadata
+ "cat", "dog", "rat", name="pet", metadata=metadata
)
schema_enum_type = postgresql.ENUM(
"sad",
"happy",
name="mood",
schema="test_schema",
- metadata=self.metadata,
+ metadata=metadata,
)
- enum_type.create(testing.db)
- schema_enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ enum_type.create(connection)
+ schema_enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty(self):
- enum_type = postgresql.ENUM(name="empty", metadata=self.metadata)
- enum_type.create(testing.db)
- inspector = inspect(testing.db)
+ def test_inspect_enum_empty(self, metadata, connection):
+ enum_type = postgresql.ENUM(name="empty", metadata=metadata)
+ enum_type.create(connection)
+ inspector = inspect(connection)
eq_(
inspector.get_enums(),
],
)
- @testing.provide_metadata
- def test_inspect_enum_empty_from_table(self):
+ def test_inspect_enum_empty_from_table(self, metadata, connection):
Table(
- "t", self.metadata, Column("x", postgresql.ENUM(name="empty"))
- ).create(testing.db)
+ "t", metadata, Column("x", postgresql.ENUM(name="empty"))
+ ).create(connection)
- t = Table("t", MetaData(), autoload_with=testing.db)
+ t = Table("t", MetaData(), autoload_with=connection)
eq_(t.c.x.type.enums, [])
def test_reflection_with_unique_constraint(self, metadata, connection):
ischema_names = None
- def setup(self):
+ def setup_test(self):
ischema_names = postgresql.PGDialect.ischema_names
postgresql.PGDialect.ischema_names = ischema_names.copy()
self.ischema_names = ischema_names
- def teardown(self):
+ def teardown_test(self):
postgresql.PGDialect.ischema_names = self.ischema_names
self.ischema_names = None
__only_on__ = "postgresql"
__backend__ = True
- def test_interval_types(self):
- for sym in [
- "YEAR",
- "MONTH",
- "DAY",
- "HOUR",
- "MINUTE",
- "SECOND",
- "YEAR TO MONTH",
- "DAY TO HOUR",
- "DAY TO MINUTE",
- "DAY TO SECOND",
- "HOUR TO MINUTE",
- "HOUR TO SECOND",
- "MINUTE TO SECOND",
- ]:
- self._test_interval_symbol(sym)
-
- @testing.provide_metadata
- def _test_interval_symbol(self, sym):
+ @testing.combinations(
+ ("YEAR",),
+ ("MONTH",),
+ ("DAY",),
+ ("HOUR",),
+ ("MINUTE",),
+ ("SECOND",),
+ ("YEAR TO MONTH",),
+ ("DAY TO HOUR",),
+ ("DAY TO MINUTE",),
+ ("DAY TO SECOND",),
+ ("HOUR TO MINUTE",),
+ ("HOUR TO SECOND",),
+ ("MINUTE TO SECOND",),
+ argnames="sym",
+ )
+ def test_interval_types(self, sym, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(fields=sym)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, sym.lower())
eq_(columns["data1"]["type"].precision, None)
- @testing.provide_metadata
- def test_interval_precision(self):
+ def test_interval_precision(self, metadata, connection):
t = Table(
"i_test",
- self.metadata,
+ metadata,
Column("id", Integer, primary_key=True),
Column("data1", INTERVAL(precision=6)),
)
- t.create(testing.db)
+ t.create(connection)
columns = {
rec["name"]: rec
- for rec in inspect(testing.db).get_columns("i_test")
+ for rec in inspect(connection).get_columns("i_test")
}
assert isinstance(columns["data1"]["type"], INTERVAL)
eq_(columns["data1"]["type"].fields, None)
Column("id4", SmallInteger, Identity()),
)
- def test_reflect_identity(self):
- insp = inspect(testing.db)
+ def test_reflect_identity(self, connection):
+ insp = inspect(connection)
default = dict(
always=False,
start=1,
from sqlalchemy.orm import Session
from sqlalchemy.sql import operators
from sqlalchemy.sql import sqltypes
-from sqlalchemy.testing import engines
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
__only_on__ = "postgresql > 8.3"
- @testing.provide_metadata
- def test_create_table(self, connection):
+ def test_create_table(self, metadata, connection):
- metadata = self.metadata
t1 = Table(
"table",
[(1, "two"), (2, "three"), (3, "three")],
)
- @testing.combinations(None, "foo")
- def test_create_table_schema_translate_map(self, symbol_name):
+ @testing.combinations(None, "foo", argnames="symbol_name")
+ def test_create_table_schema_translate_map(self, connection, symbol_name):
# note we can't use the fixture here because it will not drop
# from the correct schema
metadata = MetaData()
),
schema=symbol_name,
)
- with testing.db.begin() as conn:
- conn = conn.execution_options(
- schema_translate_map={symbol_name: testing.config.test_schema}
- )
- t1.create(conn)
- assert "schema_enum" in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.create(conn, checkfirst=True)
+ conn = connection.execution_options(
+ schema_translate_map={symbol_name: testing.config.test_schema}
+ )
+ t1.create(conn)
+ assert "schema_enum" in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.create(conn, checkfirst=True)
- conn.execute(t1.insert(), value="two")
- conn.execute(t1.insert(), value="three")
- conn.execute(t1.insert(), value="three")
- eq_(
- conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
- [(1, "two"), (2, "three"), (3, "three")],
- )
+ conn.execute(t1.insert(), value="two")
+ conn.execute(t1.insert(), value="three")
+ conn.execute(t1.insert(), value="three")
+ eq_(
+ conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
+ [(1, "two"), (2, "three"), (3, "three")],
+ )
- t1.drop(conn)
- assert "schema_enum" not in [
- e["name"]
- for e in inspect(conn).get_enums(
- schema=testing.config.test_schema
- )
- ]
- t1.drop(conn, checkfirst=True)
+ t1.drop(conn)
+ assert "schema_enum" not in [
+ e["name"]
+ for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+ ]
+ t1.drop(conn, checkfirst=True)
def test_name_required(self, metadata, connection):
etype = Enum("four", "five", "six", metadata=metadata)
[util.u("réveillé"), util.u("drôle"), util.u("S’il")],
)
- @testing.provide_metadata
- def test_non_native_enum(self, connection):
+ def test_non_native_enum(self, metadata, connection):
- metadata = self.metadata
t1 = Table(
"foo",
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
connection.execute(t1.insert(), {"bar": "two"})
eq_(connection.scalar(select(t1.c.bar)), "two")
- @testing.provide_metadata
- def test_non_native_enum_w_unicode(self, connection):
+ def test_non_native_enum_w_unicode(self, metadata, connection):
- metadata = self.metadata
t1 = Table(
"foo",
)
def go():
- t1.create(testing.db)
+ t1.create(connection)
self.assert_sql(
- testing.db,
+ connection,
go,
[
(
connection.execute(t1.insert(), {"bar": util.u("Ü")})
eq_(connection.scalar(select(t1.c.bar)), util.u("Ü"))
- @testing.provide_metadata
- def test_disable_create(self):
+ def test_disable_create(self, metadata, connection):
- metadata = self.metadata
e1 = postgresql.ENUM(
t1 = Table("e1", metadata, Column("c1", e1))
# table can be created separately
# without conflict
- e1.create(bind=testing.db)
- t1.create(testing.db)
- t1.drop(testing.db)
- e1.drop(bind=testing.db)
+ e1.create(bind=connection)
+ t1.create(connection)
+ t1.drop(connection)
+ e1.drop(bind=connection)
- @testing.provide_metadata
- def test_dont_keep_checking(self, connection):
+ def test_dont_keep_checking(self, metadata, connection):
- metadata = self.metadata
e1 = postgresql.ENUM("one", "two", "three", name="myenum")
e["name"] for e in inspect(connection).get_enums()
]
- def test_non_native_dialect(self):
- engine = engines.testing_engine()
+ def test_non_native_dialect(self, metadata, testing_engine):
+ engine = testing_engine()
engine.connect()
engine.dialect.supports_native_enum = False
- metadata = MetaData()
t1 = Table(
"foo",
metadata,
def go():
t1.create(engine)
- try:
- self.assert_sql(
- engine,
- go,
- [
- (
- "CREATE TABLE foo (bar "
- "VARCHAR(5), CONSTRAINT myenum CHECK "
- "(bar IN ('one', 'two', 'three')))",
- {},
- )
- ],
- )
- finally:
- metadata.drop_all(engine)
+ self.assert_sql(
+ engine,
+ go,
+ [
+ (
+ "CREATE TABLE foo (bar "
+ "VARCHAR(5), CONSTRAINT myenum CHECK "
+ "(bar IN ('one', 'two', 'three')))",
+ {},
+ )
+ ],
+ )
def test_standalone_enum(self, connection, metadata):
etype = Enum(
)
etype.create(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
etype.drop(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
metadata.create_all(connection)
try:
- assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+ assert connection.dialect.has_type(connection, "fourfivesixtype")
finally:
metadata.drop_all(connection)
- assert not testing.db.dialect.has_type(
+ assert not connection.dialect.has_type(
connection, "fourfivesixtype"
)
- def test_no_support(self):
+ def test_no_support(self, testing_engine):
def server_version_info(self):
return (8, 2)
- e = engines.testing_engine()
+ e = testing_engine()
dialect = e.dialect
dialect._get_server_version_info = server_version_info
eq_(t2.c.value2.type.name, "fourfivesixtype")
eq_(t2.c.value2.type.schema, "test_schema")
- @testing.provide_metadata
- def test_custom_subclass(self, connection):
+ def test_custom_subclass(self, metadata, connection):
class MyEnum(TypeDecorator):
impl = Enum("oneHI", "twoHI", "threeHI", name="myenum")
return value
t1 = Table("table1", self.metadata, Column("data", MyEnum()))
- self.metadata.create_all(testing.db)
+ self.metadata.create_all(connection)
connection.execute(t1.insert(), {"data": "two"})
eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
- @testing.provide_metadata
- def test_generic_w_pg_variant(self, connection):
+ def test_generic_w_pg_variant(self, metadata, connection):
some_table = Table(
"some_table",
- self.metadata,
+ metadata,
e["name"] for e in inspect(connection).get_enums()
]
- @testing.provide_metadata
- def test_generic_w_some_other_variant(self, connection):
+ def test_generic_w_some_other_variant(self, metadata, connection):
some_table = Table(
"some_table",
- self.metadata,
+ metadata,
__only_on__ = "postgresql"
__backend__ = True
- @staticmethod
- def _scalar(expression):
- with testing.db.connect() as conn:
- return conn.scalar(select(expression))
+ @testing.fixture()
+ def scalar(self, connection):
+ def go(expression):
+ return connection.scalar(select(expression))
+
+ return go
- def test_cast_name(self):
- eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+ def test_cast_name(self, scalar):
+ eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
- def test_cast_path(self):
+ def test_cast_path(self, scalar):
eq_(
- self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
+ scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
"pg_class",
)
- def test_cast_oid(self):
+ def test_cast_oid(self, scalar):
regclass = cast("pg_class", postgresql.REGCLASS)
- oid = self._scalar(cast(regclass, postgresql.OID))
+ oid = scalar(cast(regclass, postgresql.OID))
assert isinstance(oid, int)
eq_(
- self._scalar(
+ scalar(
cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
),
"pg_class",
Column("dimarr", ProcValue),
)
- def _fixture_456(self, table):
- with testing.db.begin() as conn:
- conn.execute(table.insert(), intarr=[4, 5, 6])
+ def _fixture_456(self, table, connection):
+ connection.execute(table.insert(), intarr=[4, 5, 6])
- def test_reflect_array_column(self):
+ def test_reflect_array_column(self, connection):
metadata2 = MetaData()
- tbl = Table("arrtable", metadata2, autoload_with=testing.db)
+ tbl = Table("arrtable", metadata2, autoload_with=connection)
assert isinstance(tbl.c.intarr.type, self.ARRAY)
assert isinstance(tbl.c.strarr.type, self.ARRAY)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
def test_array_getitem_single_exec(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(connection.scalar(select(arrtable.c.intarr[2])), 5)
connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7}))
eq_(connection.scalar(select(arrtable.c.intarr[2])), 7)
set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]),
)
- def test_array_plus_native_enum_create(self):
- m = MetaData()
+ def test_array_plus_native_enum_create(self, metadata, connection):
t = Table(
"t",
- m,
+ metadata,
Column(
"data_1",
self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")),
),
)
- t.create(testing.db)
+ t.create(connection)
eq_(
- set(e["name"] for e in inspect(testing.db).get_enums()),
+ set(e["name"] for e in inspect(connection).get_enums()),
set(["my_enum_1", "my_enum_2"]),
)
- t.drop(testing.db)
- eq_(inspect(testing.db).get_enums(), [])
+ t.drop(connection)
+ eq_(inspect(connection).get_enums(), [])
class CoreArrayRoundTripTest(
):
ARRAY = postgresql.ARRAY
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_undim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_undim_array_contains_typed_exec(self, struct, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(arrtable.c.intarr).where(
- arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(arrtable.c.intarr).where(
+ arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
- @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
- def test_dim_array_contains_typed_exec(self, struct):
+ @testing.combinations(
+ (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+ )
+ def test_dim_array_contains_typed_exec(self, struct, connection):
dim_arrtable = self.tables.dim_arrtable
- self._fixture_456(dim_arrtable)
- with testing.db.begin() as conn:
- eq_(
- conn.scalar(
- select(dim_arrtable.c.intarr).where(
- dim_arrtable.c.intarr.contains(struct([4, 5]))
- )
- ),
- [4, 5, 6],
- )
+ self._fixture_456(dim_arrtable, connection)
+ eq_(
+ connection.scalar(
+ select(dim_arrtable.c.intarr).where(
+ dim_arrtable.c.intarr.contains(struct([4, 5]))
+ )
+ ),
+ [4, 5, 6],
+ )
def test_array_contained_by_exec(self, connection):
arrtable = self.tables.arrtable
def test_undim_array_empty(self, connection):
arrtable = self.tables.arrtable
- self._fixture_456(arrtable)
+ self._fixture_456(arrtable, connection)
eq_(
connection.scalar(
select(arrtable.c.intarr).where(arrtable.c.intarr.contains([]))
sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
- @testing.provide_metadata
- def test_raises_non_native_enums(self, array_cls, enum_cls):
+ def test_raises_non_native_enums(
+ self, metadata, connection, array_cls, enum_cls
+ ):
Table(
"my_table",
- self.metadata,
+ metadata,
"for ARRAY of non-native ENUM; please specify "
"create_constraint=False on this Enum datatype.",
- self.metadata.create_all,
+ metadata.create_all,
- testing.db,
+ connection,
)
@testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
(_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
argnames="array_cls",
)
- @testing.provide_metadata
- def test_array_of_enums(self, array_cls, enum_cls, connection):
+ def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
tbl = Table(
"enum_table",
- self.metadata,
+ metadata,
@testing.combinations(
sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls"
)
- @testing.provide_metadata
- def test_array_of_json(self, array_cls, json_cls, connection):
+ def test_array_of_json(self, array_cls, json_cls, metadata, connection):
tbl = Table(
"json_table",
- self.metadata,
+ metadata,
},
],
),
+ (
+ "HSTORE",
+ postgresql.HSTORE(),
+ [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
+ testing.requires.hstore,
+ ),
+ (
+ "JSONB",
+ postgresql.JSONB(),
+ [
+ {"a": "1", "b": "2", "c": "3"},
+ {
+ "d": "4",
+ "e": {"e1": "5", "e2": "6"},
+ "f": {"f1": [9, 10, 11]},
+ },
+ ],
+ testing.requires.postgresql_jsonb,
+ ),
+ argnames="type_,data",
id_="iaa",
)
- @testing.provide_metadata
- def test_hashable_flag(self, type_, data):
- Base = declarative_base(metadata=self.metadata)
+ def test_hashable_flag(self, metadata, connection, type_, data):
+ Base = declarative_base(metadata=metadata)
class A(Base):
__tablename__ = "a1"
id = Column(Integer, primary_key=True)
data = Column(type_)
- Base.metadata.create_all(testing.db)
- s = Session(testing.db)
+ Base.metadata.create_all(connection)
+ s = Session(connection)
s.add_all([A(data=elem) for elem in data])
s.commit()
list(enumerate(data, 1)),
)
- @testing.requires.hstore
- def test_hstore(self):
- self.test_hashable_flag(
- postgresql.HSTORE(),
- [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
- )
-
- @testing.requires.postgresql_jsonb
- def test_jsonb(self):
- self.test_hashable_flag(
- postgresql.JSONB(),
- [
- {"a": "1", "b": "2", "c": "3"},
- {
- "d": "4",
- "e": {"e1": "5", "e2": "6"},
- "f": {"f1": [9, 10, 11]},
- },
- ],
- )
-
class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
return table
- def test_reflection(self, special_types_table):
+ def test_reflection(self, special_types_table, connection):
# cheat so that the "strict type check"
# works
special_types_table.c.year_interval.type = postgresql.INTERVAL()
special_types_table.c.month_interval.type = postgresql.INTERVAL()
m = MetaData()
- t = Table("sometable", m, autoload_with=testing.db)
+ t = Table("sometable", m, autoload_with=connection)
self.assert_tables_equal(special_types_table, t, strict_types=True)
assert t.c.plain_interval.type.precision is None
class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
Column("data", HSTORE),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- conn.execute(
- data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
- {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
- {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
- {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
- )
+ connection.execute(
+ data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
+ {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
+ {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
+ {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
+ )
def _assert_data(self, compare, conn):
data = conn.execute(
).fetchall()
eq_([d for d, in data], compare)
- def _test_insert(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
- )
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ def _test_insert(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+ )
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _non_native_engine(self):
- if testing.requires.psycopg2_native_hstore.enabled:
- engine = engines.testing_engine(
- options=dict(use_native_hstore=False)
- )
+ @testing.fixture
+ def non_native_hstore_connection(self, testing_engine):
+ local_engine = testing.requires.psycopg2_native_hstore.enabled
+
+ if local_engine:
+ engine = testing_engine(options=dict(use_native_hstore=False))
else:
engine = testing.db
- engine.connect().close()
- return engine
- def test_reflect(self):
- insp = inspect(testing.db)
+ conn = engine.connect()
+ trans = conn.begin()
+ yield conn
+ try:
+ trans.rollback()
+ finally:
+ conn.close()
+
+ def test_reflect(self, connection):
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[2]["type"], HSTORE)
eq_(connection.scalar(select(expr)), "3")
@testing.requires.psycopg2_native_hstore
- def test_insert_native(self):
- engine = testing.db
- self._test_insert(engine)
+ def test_insert_native(self, connection):
+ self._test_insert(connection)
- def test_insert_python(self):
- engine = self._non_native_engine()
- self._test_insert(engine)
+ def test_insert_python(self, non_native_hstore_connection):
+ self._test_insert(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_criterion_native(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_native(self, connection):
+ self._fixture_data(connection)
+ self._test_criterion(connection)
- def test_criterion_python(self):
- engine = self._non_native_engine()
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion_python(self, non_native_hstore_connection):
+ self._fixture_data(non_native_hstore_connection)
+ self._test_criterion(non_native_hstore_connection)
- def _test_criterion(self, engine):
+ def _test_criterion(self, connection):
data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"] == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+ result = connection.execute(
+ select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1")
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
- def _test_fixed_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(["key1", "key2", "key3"]),
- array(["value1", "value2", "value3"]),
- )
- )
- eq_(
- conn.scalar(s),
- {"key1": "value1", "key2": "value2", "key3": "value3"},
+ def _test_fixed_round_trip(self, connection):
+ s = select(
+ hstore(
+ array(["key1", "key2", "key3"]),
+ array(["value1", "value2", "value3"]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
- def test_fixed_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_python(self, non_native_hstore_connection):
+ self._test_fixed_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_fixed_round_trip_native(self):
- engine = testing.db
- self._test_fixed_round_trip(engine)
+ def test_fixed_round_trip_native(self, connection):
+ self._test_fixed_round_trip(connection)
- def _test_unicode_round_trip(self, engine):
- with engine.begin() as conn:
- s = select(
- hstore(
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- array(
- [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
- ),
- )
- )
- eq_(
- conn.scalar(s),
- {
- util.u("réveillé"): util.u("réveillé"),
- util.u("drôle"): util.u("drôle"),
- util.u("S’il"): util.u("S’il"),
- },
+ def _test_unicode_round_trip(self, connection):
+ s = select(
+ hstore(
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
+ array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
)
+ )
+ eq_(
+ connection.scalar(s),
+ {
+ util.u("réveillé"): util.u("réveillé"),
+ util.u("drôle"): util.u("drôle"),
+ util.u("S’il"): util.u("S’il"),
+ },
+ )
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_python(self, non_native_hstore_connection):
+ self._test_unicode_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_unicode_round_trip_native(self):
- engine = testing.db
- self._test_unicode_round_trip(engine)
+ def test_unicode_round_trip_native(self, connection):
+ self._test_unicode_round_trip(connection)
- def test_escaped_quotes_round_trip_python(self):
- engine = self._non_native_engine()
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_python(
+ self, non_native_hstore_connection
+ ):
+ self._test_escaped_quotes_round_trip(non_native_hstore_connection)
@testing.requires.psycopg2_native_hstore
- def test_escaped_quotes_round_trip_native(self):
- engine = testing.db
- self._test_escaped_quotes_round_trip(engine)
+ def test_escaped_quotes_round_trip_native(self, connection):
+ self._test_escaped_quotes_round_trip(connection)
- def _test_escaped_quotes_round_trip(self, engine):
- with engine.begin() as conn:
- conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
- )
- self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn)
+ def _test_escaped_quotes_round_trip(self, connection):
+ connection.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
+ )
+ self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection)
- def test_orm_round_trip(self):
+ def test_orm_round_trip(self, connection):
from sqlalchemy import orm
class Data(object):
self.data = data
orm.mapper(Data, self.tables.data_table)
- s = orm.Session(testing.db)
- d = Data(
- name="r1",
- data={"key1": "value1", "key2": "value2", "key3": "value3"},
- )
- s.add(d)
- eq_(s.query(Data.data, Data).all(), [(d.data, d)])
+
+ with orm.Session(connection) as s:
+ d = Data(
+ name="r1",
+ data={"key1": "value1", "key2": "value2", "key3": "value3"},
+ )
+ s.add(d)
+ eq_(s.query(Data.data, Data).all(), [(d.data, d)])
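``Session`` supports the context-manager protocol as of 1.4, which is what the
rewrite above leans on: the block guarantees ``close()`` while the outer
``connection`` stays usable. Using the ``Data`` mapping from this test::

    with orm.Session(connection) as s:
        s.add(Data(name="r1", data={"key1": "value1"}))
        s.commit()
    # the Session is closed on exiting the block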
class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
# operator tests
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
table = Table(
"data_table",
MetaData(),
def test_actual_type(self):
eq_(str(self._col_type()), self._col_str)
- def test_reflect(self):
+ def test_reflect(self, connection):
from sqlalchemy import inspect
- insp = inspect(testing.db)
+ insp = inspect(connection)
cols = insp.get_columns("data_table")
assert isinstance(cols[0]["type"], self._col_type)
def tstzs(self):
if self._tstzs is None:
- with testing.db.begin() as conn:
- lower = conn.scalar(func.current_timestamp().select())
+ with testing.db.connect() as connection:
+ lower = connection.scalar(func.current_timestamp().select())
upper = lower + datetime.timedelta(1)
self._tstzs = (lower, upper)
return self._tstzs
class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
__dialect__ = "postgresql"
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
Column("nulldata", cls.data_type(none_as_null=True)),
)
- def _fixture_data(self, engine):
+ def _fixture_data(self, connection):
data_table = self.tables.data_table
data = [
{"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
{"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
]
- with engine.begin() as conn:
- conn.execute(data_table.insert(), data)
+ connection.execute(data_table.insert(), data)
return data
def _assert_data(self, compare, conn, column="data"):
).fetchall()
eq_([d for d, in data], [None])
- def _test_insert(self, conn):
- conn.execute(
+ def test_reflect(self, connection):
+ insp = inspect(connection)
+ cols = insp.get_columns("data_table")
+ assert isinstance(cols[2]["type"], self.data_type)
+
+ def test_insert(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
)
- self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+ self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
- def _test_insert_nulls(self, conn):
- conn.execute(
+ def test_insert_nulls(self, connection):
+ connection.execute(
self.tables.data_table.insert(), {"name": "r1", "data": null()}
)
- self._assert_data([None], conn)
+ self._assert_data([None], connection)
- def _test_insert_none_as_null(self, conn):
- conn.execute(
+ def test_insert_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": None},
)
- self._assert_column_is_NULL(conn, column="nulldata")
+ self._assert_column_is_NULL(connection, column="nulldata")
- def _test_insert_nulljson_into_none_as_null(self, conn):
- conn.execute(
+ def test_insert_nulljson_into_none_as_null(self, connection):
+ connection.execute(
self.tables.data_table.insert(),
{"name": "r1", "nulldata": JSON.NULL},
)
- self._assert_column_is_JSON_NULL(conn, column="nulldata")
-
- def test_reflect(self):
- insp = inspect(testing.db)
- cols = insp.get_columns("data_table")
- assert isinstance(cols[2]["type"], self.data_type)
-
- def test_insert(self, connection):
- self._test_insert(connection)
-
- def test_insert_nulls(self, connection):
- self._test_insert_nulls(connection)
+ self._assert_column_is_JSON_NULL(connection, column="nulldata")
- def test_insert_none_as_null(self, connection):
- self._test_insert_none_as_null(connection)
-
- def test_insert_nulljson_into_none_as_null(self, connection):
- self._test_insert_nulljson_into_none_as_null(connection)
-
- def test_custom_serialize_deserialize(self):
+ def test_custom_serialize_deserialize(self, testing_engine):
import json
def dumps(value):
value["x"] = "dumps_y"
return json.dumps(value)
- engine = engines.testing_engine(
+ engine = testing_engine(
options=dict(json_serializer=dumps, json_deserializer=loads)
)
with engine.begin() as conn:
eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
- def test_criterion(self):
- engine = testing.db
- self._fixture_data(engine)
- self._test_criterion(engine)
+ def test_criterion(self, connection):
+ self._fixture_data(connection)
+ data_table = self.tables.data_table
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+
+ result = connection.execute(
+ select(data_table.c.data).where(
+ data_table.c.data["k1"].astext.cast(String) == "r3v1"
+ )
+ ).first()
+ eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
def test_path_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
"postgresql < 9.4", "Improvement in PostgreSQL behavior?"
)
def test_multi_index_query(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
eq_(result.scalar(), "r6")
def test_query_returned_as_text(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k1"].astext)
).first()
- if engine.dialect.returns_unicode_strings:
+ if connection.dialect.returns_unicode_strings:
assert isinstance(result[0], util.text_type)
else:
assert isinstance(result[0], util.string_types)
def test_query_returned_as_int(self, connection):
- engine = testing.db
- self._fixture_data(engine)
+ self._fixture_data(connection)
data_table = self.tables.data_table
result = connection.execute(
select(data_table.c.data["k3"].astext.cast(Integer)).where(
).first()
assert isinstance(result[0], int)
- def _test_criterion(self, engine):
- data_table = self.tables.data_table
- with engine.begin() as conn:
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
- result = conn.execute(
- select(data_table.c.data).where(
- data_table.c.data["k1"].astext.cast(String) == "r3v1"
- )
- ).first()
- eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
def test_fixed_round_trip(self, connection):
s = select(
cast(
},
)
- def test_eval_none_flag_orm(self):
+ def test_eval_none_flag_orm(self, connection):
Base = declarative_base()
class Data(Base):
__table__ = self.tables.data_table
- s = Session(testing.db)
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
- d1 = Data(name="d1", data=None, nulldata=None)
- s.add(d1)
- s.commit()
-
- s.bulk_insert_mappings(
- Data, [{"name": "d2", "data": None, "nulldata": None}]
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
)
- .filter(self.tables.data_table.c.name == "d1")
- .first(),
- ("null", None),
- )
- eq_(
- s.query(
- cast(self.tables.data_table.c.data, String),
- cast(self.tables.data_table.c.nulldata, String),
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
)
- .filter(self.tables.data_table.c.name == "d2")
- .first(),
- ("null", None),
- )
def test_literal(self, connection):
- exp = self._fixture_data(testing.db)
+ exp = self._fixture_data(connection)
result = connection.exec_driver_sql(
"select data from data_table order by name"
)
eq_(len(res), len(exp))
for row, expected in zip(res, exp):
eq_(row[0], expected["data"])
- result.close()
class JSONBTest(JSONTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
from sqlalchemy import util
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.dialects.sqlite import provision
from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect
from sqlalchemy.engine.url import make_url
from sqlalchemy.schema import CreateTable
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
def _fixture(self):
meta = self.metadata
- self.conn = testing.db.connect()
+ self.conn = self.engine.connect()
Table("created", meta, Column("foo", Integer), Column("bar", String))
Table("local_only", meta, Column("q", Integer), Column("p", Integer))
meta.create_all(self.conn)
return ct
- def setup(self):
- self.conn = testing.db.connect()
+ def setup_test(self):
+ self.engine = engines.testing_engine(options={"use_reaper": False})
+
+ provision._sqlite_post_configure_engine(
+ self.engine.url, self.engine, config.ident
+ )
+ self.conn = self.engine.connect()
self.metadata = MetaData()
- def teardown(self):
+ def teardown_test(self):
with self.conn.begin():
self.metadata.drop_all(self.conn)
self.conn.close()
+ self.engine.dispose()
def test_no_tables(self):
insp = inspect(self.conn)
__skip_if__ = (full_text_search_missing,)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata, cattable, matchtable
metadata = MetaData()
exec_sql(
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_expression(self):
class ReflectHeadlessFKsTest(fixtures.TestBase):
__only_on__ = "sqlite"
- def setup(self):
+ def setup_test(self):
exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)")
# this syntax actually works on other DBs; perhaps we'd want to add
# tests to test_reflection
testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)"
)
- def teardown(self):
+ def teardown_test(self):
exec_sql(testing.db, "drop table b")
exec_sql(testing.db, "drop table a")
__only_on__ = "sqlite"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
with testing.db.begin() as conn:
conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)")
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
for name in [
"implicit_referrer_comp_fake",
@classmethod
def setup_bind(cls):
- engine = engines.testing_engine(options={"use_reaper": False})
+ engine = engines.testing_engine(options={"scope": "class"})
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "sqlite"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
class DDLEventTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.bind = engines.mock_engine()
self.metadata = MetaData()
self.table = Table("t", self.metadata, Column("id", Integer))
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _queuepool_fixture(self, **kw):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
event.listen(
engine, "before_cursor_execute", cursor_execute, retval=True
)
+
with testing.expect_deprecated(
r"The argument signature for the "
r"\"ConnectionEvents.before_execute\" event listener",
r"The argument signature for the "
r"\"ConnectionEvents.after_execute\" event listener",
):
- e1.execute(select(1))
+ result = e1.execute(select(1))
+ result.close()
class DDLExecutionTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.engine = engines.mock_engine()
self.metadata = MetaData()
self.users = Table(
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.mock import call
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.mock import patch
).default_from()
)
- conn = testing.db.connect()
- result = (
- conn.execution_options(no_parameters=True)
- .exec_driver_sql(stmt)
- .scalar()
- )
- eq_(result, "%")
+ with testing.db.connect() as conn:
+ result = (
+ conn.execution_options(no_parameters=True)
+ .exec_driver_sql(stmt)
+ .scalar()
+ )
+ eq_(result, "%")
def test_raw_positional_invalid(self, connection):
assert_raises_message(
(4, "sally"),
]
- @testing.engines.close_open_connections
def test_exception_wrapping_dbapi(self):
- conn = testing.db.connect()
- # engine does not have exec_driver_sql
- assert_raises_message(
- tsa.exc.DBAPIError,
- r"not_a_valid_statement",
- conn.exec_driver_sql,
- "not_a_valid_statement",
- )
+ with testing.db.connect() as conn:
+ # engine does not have exec_driver_sql
+ assert_raises_message(
+ tsa.exc.DBAPIError,
+ r"not_a_valid_statement",
+ conn.exec_driver_sql,
+ "not_a_valid_statement",
+ )
@testing.requires.sqlite
def test_exception_wrapping_non_dbapi_error(self):
["sqlite", "mysql", "postgresql"],
"uses blob value that is problematic for some DBAPIs",
)
- @testing.provide_metadata
- def test_cache_noleak_on_statement_values(self, connection):
+ def test_cache_noleak_on_statement_values(self, metadata, connection):
# This is a non regression test for an object reference leak caused
# by the compiled_cache.
- metadata = self.metadata
photo = Table(
"photo",
metadata,
__requires__ = ("schemas",)
__backend__ = True
- def test_create_table(self):
+ @testing.fixture
+ def plain_tables(self, metadata):
+ t1 = Table(
+ "t1", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t2 = Table(
+ "t2", metadata, Column("x", Integer), schema=config.test_schema
+ )
+ t3 = Table("t3", metadata, Column("x", Integer), schema=None)
+
+ return t1, t2, t3
+
+ def test_create_table(self, plain_tables, connection):
map_ = {
None: config.test_schema,
"foo": config.test_schema,
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- t1.create(conn)
- t2.create(conn)
- t3.create(conn)
+ t1.create(conn)
+ t2.create(conn)
+ t3.create(conn)
- t3.drop(conn)
- t2.drop(conn)
- t1.drop(conn)
+ t3.drop(conn)
+ t2.drop(conn)
+ t1.drop(conn)
asserter.assert_(
CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"),
CompiledSQL("DROP TABLE [SCHEMA__none].t1"),
)
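
A minimal sketch of the option these assertions exercise, assuming a hypothetical URL: ``schema_translate_map`` rewrites schema names at execution time, with the ``None`` key covering tables that have no explicit schema.

    from sqlalchemy import create_engine

    engine = create_engine("postgresql://scott:tiger@localhost/test")

    conn = engine.connect().execution_options(
        schema_translate_map={
            None: "test_schema",   # Table(..., schema=None)
            "foo": "test_schema",  # Table(..., schema="foo")
            "bar": "bar_schema",   # Table(..., schema="bar")
        }
    )
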
- def _fixture(self):
- metadata = self.metadata
- Table("t1", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t2", metadata, Column("x", Integer), schema=config.test_schema)
- Table("t3", metadata, Column("x", Integer), schema=None)
- metadata.create_all(testing.db)
-
- def test_ddl_hastable(self):
+ def test_ddl_hastable(self, plain_tables, connection):
map_ = {
None: config.test_schema,
Table("t2", metadata, Column("x", Integer), schema="foo")
Table("t3", metadata, Column("x", Integer), schema="bar")
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.create_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+ metadata.create_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_true(insp.has_table("t1", schema=config.test_schema))
is_true(insp.has_table("t2", schema=config.test_schema))
is_true(insp.has_table("t3", schema=None))
- with config.db.begin() as conn:
- conn = conn.execution_options(schema_translate_map=map_)
- metadata.drop_all(conn)
+ conn = connection.execution_options(schema_translate_map=map_)
+
+ # if this test fails, the tables won't get dropped, so we need a
+ # more robust fixture for this
+ metadata.drop_all(conn)
- insp = inspect(config.db)
+ insp = inspect(connection)
is_false(insp.has_table("t1", schema=config.test_schema))
is_false(insp.has_table("t2", schema=config.test_schema))
is_false(insp.has_table("t3", schema=None))
- @testing.provide_metadata
- def test_option_on_execute(self):
- self._fixture()
+ def test_option_on_execute(self, plain_tables, connection):
+ # the metadata fixture is provided by the plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection
+ execution_options = {"schema_translate_map": map_}
+ conn._execute_20(
+ t1.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t2.insert(), {"x": 1}, execution_options=execution_options
+ )
+ conn._execute_20(
+ t3.insert(), {"x": 1}, execution_options=execution_options
+ )
- execution_options = {"schema_translate_map": map_}
- conn._execute_20(
- t1.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t2.insert(), {"x": 1}, execution_options=execution_options
- )
- conn._execute_20(
- t3.insert(), {"x": 1}, execution_options=execution_options
- )
+ conn._execute_20(
+ t1.update().values(x=1).where(t1.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t2.update().values(x=2).where(t2.c.x == 1),
+ execution_options=execution_options,
+ )
+ conn._execute_20(
+ t3.update().values(x=3).where(t3.c.x == 1),
+ execution_options=execution_options,
+ )
+ eq_(
conn._execute_20(
- t1.update().values(x=1).where(t1.c.x == 1),
- execution_options=execution_options,
- )
+ select(t1.c.x), execution_options=execution_options
+ ).scalar(),
+ 1,
+ )
+ eq_(
conn._execute_20(
- t2.update().values(x=2).where(t2.c.x == 1),
- execution_options=execution_options,
- )
+ select(t2.c.x), execution_options=execution_options
+ ).scalar(),
+ 2,
+ )
+ eq_(
conn._execute_20(
- t3.update().values(x=3).where(t3.c.x == 1),
- execution_options=execution_options,
- )
-
- eq_(
- conn._execute_20(
- select(t1.c.x), execution_options=execution_options
- ).scalar(),
- 1,
- )
- eq_(
- conn._execute_20(
- select(t2.c.x), execution_options=execution_options
- ).scalar(),
- 2,
- )
- eq_(
- conn._execute_20(
- select(t3.c.x), execution_options=execution_options
- ).scalar(),
- 3,
- )
+ select(t3.c.x), execution_options=execution_options
+ ).scalar(),
+ 3,
+ )
- conn._execute_20(
- t1.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t2.delete(), execution_options=execution_options
- )
- conn._execute_20(
- t3.delete(), execution_options=execution_options
- )
+ conn._execute_20(t1.delete(), execution_options=execution_options)
+ conn._execute_20(t2.delete(), execution_options=execution_options)
+ conn._execute_20(t3.delete(), execution_options=execution_options)
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_crud(self):
- self._fixture()
+ def test_crud(self, plain_tables, connection):
+ # the metadata fixture is provided by the plain_tables fixture
+ self.metadata.create_all(connection)
map_ = {
None: config.test_schema,
t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
- with self.sql_execution_asserter(config.db) as asserter:
- with config.db.begin() as conn, conn.execution_options(
- schema_translate_map=map_
- ) as conn:
+ with self.sql_execution_asserter(connection) as asserter:
+ conn = connection.execution_options(schema_translate_map=map_)
- conn.execute(t1.insert(), {"x": 1})
- conn.execute(t2.insert(), {"x": 1})
- conn.execute(t3.insert(), {"x": 1})
+ conn.execute(t1.insert(), {"x": 1})
+ conn.execute(t2.insert(), {"x": 1})
+ conn.execute(t3.insert(), {"x": 1})
- conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
- conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
- conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
+ conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
+ conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
+ conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
- eq_(conn.scalar(select(t1.c.x)), 1)
- eq_(conn.scalar(select(t2.c.x)), 2)
- eq_(conn.scalar(select(t3.c.x)), 3)
+ eq_(conn.scalar(select(t1.c.x)), 1)
+ eq_(conn.scalar(select(t2.c.x)), 2)
+ eq_(conn.scalar(select(t3.c.x)), 3)
- conn.execute(t1.delete())
- conn.execute(t2.delete())
- conn.execute(t3.delete())
+ conn.execute(t1.delete())
+ conn.execute(t2.delete())
+ conn.execute(t3.delete())
asserter.assert_(
CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
)
- @testing.provide_metadata
- def test_via_engine(self):
- self._fixture()
+ def test_via_engine(self, plain_tables, metadata):
+
+ with config.db.begin() as connection:
+ metadata.create_all(connection)
map_ = {
None: config.test_schema,
with self.sql_execution_asserter(config.db) as asserter:
eng = config.db.execution_options(schema_translate_map=map_)
- conn = eng.connect()
- conn.execute(select(t2.c.x))
+ with eng.connect() as conn:
+ conn.execute(select(t2.c.x))
asserter.assert_(
CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2")
)
class ExecutionOptionsTest(fixtures.TestBase):
- def test_dialect_conn_options(self):
+ def test_dialect_conn_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c2 = conn.execution_options(foo="bar")
- eq_(
- engine.dialect.set_connection_execution_options.mock_calls,
- [call(c2, {"foo": "bar"})],
- )
+ with engine.connect() as conn:
+ c2 = conn.execution_options(foo="bar")
+ eq_(
+ engine.dialect.set_connection_execution_options.mock_calls,
+ [call(c2, {"foo": "bar"})],
+ )
- def test_dialect_engine_options(self):
+ def test_dialect_engine_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
[call(engine, {"foo": "bar"})],
)
- def test_propagate_engine_to_connection(self):
+ def test_propagate_engine_to_connection(self, testing_engine):
engine = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
- conn = engine.connect()
- eq_(conn._execution_options, {"foo": "bar"})
+ with engine.connect() as conn:
+ eq_(conn._execution_options, {"foo": "bar"})
- def test_propagate_option_engine_to_connection(self):
+ def test_propagate_option_engine_to_connection(self, testing_engine):
e1 = testing_engine(
"sqlite://", options=dict(execution_options={"foo": "bar"})
)
eq_(c1._execution_options, {"foo": "bar"})
eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})
- def test_get_engine_execution_options(self):
+ c1.close()
+ c2.close()
+
+ def test_get_engine_execution_options(self, testing_engine):
engine = testing_engine("sqlite://")
engine.dialect = Mock()
e2 = engine.execution_options(foo="bar")
eq_(e2.get_execution_options(), {"foo": "bar"})
- def test_get_connection_execution_options(self):
+ def test_get_connection_execution_options(self, testing_engine):
engine = testing_engine("sqlite://", options=dict(_initialize=False))
engine.dialect = Mock()
- conn = engine.connect()
- c = conn.execution_options(foo="bar")
+ with engine.connect() as conn:
+ c = conn.execution_options(foo="bar")
- eq_(c.get_execution_options(), {"foo": "bar"})
+ eq_(c.get_execution_options(), {"foo": "bar"})
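
As a quick reference for the accessors pinned down above, a sketch: ``Engine.execution_options()`` returns a derived engine, ``Connection.execution_options()`` applies options to the connection, and ``get_execution_options()`` reports what is in effect.

    from sqlalchemy import create_engine

    engine = create_engine("sqlite://")

    e2 = engine.execution_options(foo="bar")
    assert e2.get_execution_options() == {"foo": "bar"}

    with engine.connect() as conn:
        c2 = conn.execution_options(foo="bar")
        assert c2.get_execution_options() == {"foo": "bar"}
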
class EngineEventsTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
):
break
- def test_per_engine_independence(self):
+ def test_per_engine_independence(self, testing_engine):
e1 = testing_engine(config.db_url)
e2 = testing_engine(config.db_url)
conn.execute(s2)
eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2])
- def test_per_engine_plus_global(self):
+ def test_per_engine_plus_global(self, testing_engine):
canary = Mock()
event.listen(Engine, "before_execute", canary.be1)
e1 = testing_engine(config.db_url)
event.listen(e1, "before_execute", canary.be2)
event.listen(Engine, "before_execute", canary.be3)
- e1.connect()
- e2.connect()
with e1.connect() as conn:
conn.execute(select(1))
eq_(canary.be2.call_count, 1)
eq_(canary.be3.call_count, 2)
- def test_per_connection_plus_engine(self):
+ def test_per_connection_plus_engine(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
eq_(canary.be1.call_count, 2)
eq_(canary.be2.call_count, 2)
- @testing.combinations((True, False), (True, True), (False, False))
+ @testing.combinations(
+ (True, False),
+ (True, True),
+ (False, False),
+ argnames="mock_out_on_connect, add_our_own_onconnect",
+ )
def test_insert_connect_is_definitely_first(
- self, mock_out_on_connect, add_our_own_onconnect
+ self, mock_out_on_connect, add_our_own_onconnect, testing_engine
):
"""test issue #5708.
patcher = util.nullcontext()
with patcher:
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url)
initialize = e1.dialect.initialize
conn.exec_driver_sql(select1(testing.db))
eq_(m1.mock_calls, [])
- def test_add_event_after_connect(self):
+ def test_add_event_after_connect(self, testing_engine):
# new feature as of #2978
+
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
conn = e1.connect()
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 2)
- def test_force_conn_events_false(self):
+ def test_force_conn_events_false(self, testing_engine):
canary = Mock()
- e1 = create_engine(config.db_url)
+ e1 = testing_engine(config.db_url, future=False)
assert not e1._has_events
event.listen(e1, "before_execute", canary.be1)
conn._branch().execute(select(1))
eq_(canary.be1.call_count, 0)
- def test_cursor_events_ctx_execute_scalar(self):
+ def test_cursor_events_ctx_execute_scalar(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
[call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
)
- def test_cursor_events_execute(self):
+ def test_cursor_events_execute(self, testing_engine):
canary = Mock()
e1 = testing_engine(config.db_url)
),
((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine),
(({"z": 10},), {}, [], {"z": 10}),
+ argnames="multiparams, params, expected_multiparams, expected_params",
)
def test_modify_parameters_from_event_one(
- self, multiparams, params, expected_multiparams, expected_params
+ self,
+ multiparams,
+ params,
+ expected_multiparams,
+ expected_params,
+ testing_engine,
):
# this is testing both the normalization added to parameters
# as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as
[(15,), (19,)],
)
- def test_modify_parameters_from_event_three(self, connection):
+ def test_modify_parameters_from_event_three(
+ self, connection, testing_engine
+ ):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
with e1.connect() as conn:
conn.execute(select(literal("1")))
- def test_argument_format_execute(self):
+ def test_argument_format_execute(self, testing_engine):
def before_execute(
conn, clauseelement, multiparams, params, execution_options
):
)
@testing.requires.ad_hoc_engines
- def test_dispose_event(self):
+ def test_dispose_event(self, testing_engine):
canary = Mock()
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
event.listen(eng, "engine_disposed", canary)
conn = eng.connect()
event.listen(engine, "commit", tracker("commit"))
event.listen(engine, "rollback", tracker("rollback"))
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
event.listen(engine, "commit", tracker("commit"), named=True)
event.listen(engine, "rollback", tracker("rollback"), named=True)
- conn = engine.connect()
- trans = conn.begin()
- conn.execute(select(1))
- trans.rollback()
- trans = conn.begin()
- conn.execute(select(1))
- trans.commit()
+ with engine.connect() as conn:
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select(1))
+ trans.commit()
eq_(
canary,
__requires__ = ("ad_hoc_engines",)
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
Engine.dispatch._clear()
Engine._has_events = False
class HandleInvalidatedOnConnectTest(fixtures.TestBase):
__requires__ = ("sqlite",)
- def setUp(self):
+ def setup_test(self):
e = create_engine("sqlite://")
connection = Mock(get_server_version_info=Mock(return_value="5.0"))
],
)
+ c.close()
+ c2.close()
+
class DialectEventTest(fixtures.TestBase):
@contextmanager
)
@testing.fixture
- def input_sizes_fixture(self):
+ def input_sizes_fixture(self, testing_engine):
canary = mock.Mock()
def do_set_input_sizes(cursor, list_of_tuples, context):
__only_on__ = "sqlite"
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.eng = engines.testing_engine(options={"echo": True})
self.no_param_engine = engines.testing_engine(
options={"echo": True, "hide_parameters": True}
for log in [logging.getLogger("sqlalchemy.engine")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
exec_sql(self.eng, "drop table if exists foo")
for log in [logging.getLogger("sqlalchemy.engine")]:
log.removeHandler(self.buf)
class PoolLoggingTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.existing_level = logging.getLogger("sqlalchemy.pool").level
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.pool")]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.pool")]:
log.removeHandler(self.buf)
logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level)
kw.update({"echo": True})
return engines.testing_engine(options=kw)
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [
logging.getLogger("sqlalchemy.engine"),
]:
log.addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
for log in [
logging.getLogger("sqlalchemy.engine"),
logging.getLogger("sqlalchemy.pool"),
class EchoTest(fixtures.TestBase):
__requires__ = ("ad_hoc_engines",)
- def setup(self):
+ def setup_test(self):
self.level = logging.getLogger("sqlalchemy.engine").level
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN)
self.buf = logging.handlers.BufferingHandler(100)
logging.getLogger("sqlalchemy.engine").addHandler(self.buf)
- def teardown(self):
+ def teardown_test(self):
logging.getLogger("sqlalchemy.engine").removeHandler(self.buf)
logging.getLogger("sqlalchemy.engine").setLevel(self.level)
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_none
from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_not_none
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.engines import testing_engine
class PoolTestBase(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
pool.clear_managers()
self._teardown_conns = []
- def teardown(self):
+ def teardown_test(self):
for ref in self._teardown_conns:
conn = ref()
if conn:
conn.close()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
pool.clear_managers()
def _with_teardown(self, connection):
p = self._queuepool_fixture()
canary = []
+ @event.listens_for(p, "checkin")
def checkin(*arg, **kw):
canary.append("checkin")
- event.listen(p, "checkin", checkin)
+ @event.listens_for(p, "close_detached")
+ def close_detached(*arg, **kw):
+ canary.append("close_detached")
+
+ @event.listens_for(p, "detach")
+ def detach(*arg, **kw):
+ canary.append("detach")
return p, canary
assert canary.call_args_list[0][0][0] is dbapi_con
assert canary.call_args_list[0][0][2] is exc
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_checkin_event_gc(self):
+ def test_checkin_event_gc(self, detach_gced):
p, canary = self._checkin_event_fixture()
+ if detach_gced:
+ p._is_asyncio = True
+
c1 = p.connect()
+
+ dbapi_connection = weakref.ref(c1.connection)
+
eq_(canary, [])
del c1
lazy_gc()
- eq_(canary, ["checkin"])
+
+ if detach_gced:
+ # "close_detached" is not called because for asyncio the
+ # connection is just lost.
+ eq_(canary, ["detach"])
+
+ else:
+ eq_(canary, ["checkin"])
+
+ gc_collect()
+ if detach_gced:
+ is_none(dbapi_connection())
+ else:
+ is_not_none(dbapi_connection())
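
A minimal sketch of the three pool hooks the fixture above registers, with signatures as documented for :class:`.PoolEvents`: ``detach`` fires when a DBAPI connection is disassociated from the pool, and ``close_detached`` when such a connection is subsequently closed.

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "checkin")
    def on_checkin(dbapi_connection, connection_record):
        print("checkin")

    @event.listens_for(engine, "detach")
    def on_detach(dbapi_connection, connection_record):
        print("detach")

    @event.listens_for(engine, "close_detached")
    def on_close_detached(dbapi_connection):
        print("close_detached")
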
def test_checkin_event_on_subsequently_recreated(self):
p, canary = self._checkin_event_fixture()
eq_(conn.info["important_flag"], True)
conn.close()
- def teardown(self):
+ def teardown_test(self):
# TODO: need to get remove() functionality
# going
pool.Pool.dispatch._clear()
self._assert_cleanup_on_pooled_reconnect(dbapi, p)
+ @testing.combinations((True, testing.requires.python3), (False,))
@testing.requires.predictable_gc
- def test_userspace_disconnectionerror_weakref_finalizer(self):
+ def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced):
dbapi, pool = self._queuepool_dbapi_fixture(
pool_size=1, max_overflow=2
)
+ if detach_gced:
+ pool._is_asyncio = True
+
@event.listens_for(pool, "checkout")
def handle_checkout_event(dbapi_con, con_record, con_proxy):
if getattr(dbapi_con, "boom") == "yes":
del conn
gc_collect()
- # new connection was reset on return appropriately
- eq_(dbapi_conn.mock_calls, [call.rollback()])
+ if detach_gced:
+ # new connection was detached + abandoned on return
+ eq_(dbapi_conn.mock_calls, [])
+ else:
+ # new connection reset and returned to pool
+ eq_(dbapi_conn.mock_calls, [call.rollback()])
# old connection was just closed - did not get an
# erroneous reset on return
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
class PyDateProcessorTest(_DateProcessorTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import processors
cls.module = type(
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cprocessors
cls.module = cprocessors
class PyDistillArgsTest(_DistillArgsTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.engine import util
cls.module = type(
__requires__ = ("cextensions",)
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import cutils as util
cls.module = util
class PrePingMockTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
def _pool_fixture(self, pre_ping, pool_kw=None):
)
return _pool
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_ping_not_on_first_connect(self):
class MockReconnectTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.dbapi = MockDBAPI()
self.db = testing_engine(
e, MockDisconnect
)
- def teardown(self):
+ def teardown_test(self):
self.dbapi.dispose()
def test_reconnect(self):
__backend__ = True
__requires__ = "graceful_disconnects", "ad_hoc_engines"
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
- def teardown(self):
+ def teardown_test(self):
self.engine.dispose()
def test_reconnect(self):
class InvalidateDuringResultTest(fixtures.TestBase):
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine()
self.meta = MetaData()
table = Table(
[{"id": i, "name": "row %d" % i} for i in range(1, 100)],
)
- def teardown(self):
+ def teardown_test(self):
with self.engine.begin() as conn:
self.meta.drop_all(conn)
self.engine.dispose()
__backend__ = True
- def setup(self):
+ def setup_test(self):
self.engine = engines.reconnecting_engine(
options=dict(future=self.future)
)
)
self.meta.create_all(self.engine)
- def teardown(self):
+ def teardown_test(self):
self.meta.drop_all(self.engine)
self.engine.dispose()
assert f1 in b1.constraints
assert len(b1.constraints) == 2
- def test_override_keys(self, connection, metadata):
+ def test_override_keys(self, metadata, connection):
"""test that columns can be overridden with a 'key',
and that ForeignKey targeting during reflection still works."""
run_create_tables = None
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
# TablesTest is used here without
# run_create_tables, so add an explicit drop of whatever is in
# metadata
@testing.requires.schemas
@testing.requires.cross_schema_fk_reflection
@testing.requires.implicit_default_schema
- @testing.provide_metadata
def test_blank_schema_arg(self, connection, metadata):
Table(
__backend__ = True
@testing.requires.denormalized_names
- def setup(self):
+ def setup_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql(
"""
)
@testing.requires.denormalized_names
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
conn.exec_driver_sql("drop table weird_casing")
import sys
-from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global metadata
metadata = MetaData()
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
metadata.drop_all(testing.db)
def test_rollback_deadlock(self):
def test_per_engine(self):
# new in 0.9
- eng = create_engine(
+ eng = testing_engine(
testing.db.url,
- execution_options={
- "isolation_level": self._non_default_isolation_level()
- },
+ options=dict(
+ execution_options={
+ "isolation_level": self._non_default_isolation_level()
+ }
+ ),
)
conn = eng.connect()
eq_(
)
def test_per_option_engine(self):
- eng = create_engine(testing.db.url).execution_options(
+ eng = testing_engine(testing.db.url).execution_options(
isolation_level=self._non_default_isolation_level()
)
)
def test_isolation_level_accessors_connection_default(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
eq_(conn.default_isolation_level, self._default_isolation_level())
with eng.connect() as conn:
eq_(conn.get_isolation_level(), self._default_isolation_level())
def test_isolation_level_accessors_connection_option_modified(self):
- eng = create_engine(testing.db.url)
+ eng = testing_engine(testing.db.url)
with eng.connect() as conn:
c2 = conn.execution_options(
isolation_level=self._non_default_isolation_level()
await trans.rollback(),
@async_test
- async def test_pool_exhausted(self, async_engine):
+ async def test_pool_exhausted_some_timeout(self, async_engine):
engine = create_async_engine(
testing.db.url,
pool_size=1,
pool_timeout=0.1,
)
async with engine.connect():
- with expect_raises(asyncio.TimeoutError):
+ with expect_raises(exc.TimeoutError):
+ await engine.connect()
+
+ @async_test
+ async def test_pool_exhausted_no_timeout(self, async_engine):
+ engine = create_async_engine(
+ testing.db.url,
+ pool_size=1,
+ max_overflow=0,
+ pool_timeout=0,
+ )
+ async with engine.connect():
+ with expect_raises(exc.TimeoutError):
await engine.connect()
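
A minimal sketch of the behavior these two tests lock down, assuming a hypothetical asyncpg URL: once the pool is exhausted, a further ``connect()`` raises :class:`.exc.TimeoutError`, and ``pool_timeout=0`` times out immediately rather than blocking.

    import asyncio

    from sqlalchemy import exc
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://scott:tiger@localhost/test",
            pool_size=1,
            max_overflow=0,
            pool_timeout=0,
        )
        async with engine.connect():
            try:
                await engine.connect()  # pool already exhausted
            except exc.TimeoutError:
                print("timed out immediately")

    asyncio.run(main())
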
@async_test
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
from sqlalchemy import testing
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import decl_api as decl
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import gc_collect
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry()
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
class DeferredReflectBase(DeclarativeReflectionBase):
- def teardown(self):
- super(DeferredReflectBase, self).teardown()
+ def teardown_test(self):
+ super(DeferredReflectBase, self).teardown_test()
_DeferredMapperConfig._configs.clear()
u1 = User(
name="u1", addresses=[Address(email="one"), Address(email="two")]
)
- sess = create_session(testing.db)
- sess.add(u1)
- sess.flush()
- sess.expunge_all()
- eq_(
- sess.query(User).all(),
- [
- User(
- name="u1",
- addresses=[Address(email="one"), Address(email="two")],
- )
- ],
- )
- a1 = sess.query(Address).filter(Address.email == "two").one()
- eq_(a1, Address(email="two"))
- eq_(a1.user, User(name="u1"))
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
+
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User).all(),
+ [
+ User(
+ name="u1",
+ addresses=[Address(email="one"), Address(email="two")],
+ )
+ ],
+ )
+ a1 = sess.query(Address).filter(Address.email == "two").one()
+ eq_(a1, Address(email="two"))
+ eq_(a1.user, User(name="u1"))
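
The pattern these hunks converge on, sketched standalone: in 1.4 a :class:`.Session` works as a context manager and is closed on exit, so each block starts clean.

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")  # hypothetical bind

    with Session(engine) as sess:
        sess.execute(text("select 1"))
        sess.commit()
    # the session is closed here; the next block opens a fresh one
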
def test_exception_prepare_not_called(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
return {"primary_key": cls.__table__.c.id}
DeferredReflection.prepare(testing.db)
- sess = Session(testing.db)
- sess.add_all(
- [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
- )
- sess.commit()
- eq_(
- sess.query(User).order_by(User.name).all(),
- [User(name="A"), User(name="C"), User(name="G"), User(name="Q")],
- )
+ with fixture_session() as sess:
+ sess.add_all(
+ [
+ User(name="G"),
+ User(name="Q"),
+ User(name="A"),
+ User(name="C"),
+ ]
+ )
+ sess.commit()
+ eq_(
+ sess.query(User).order_by(User.name).all(),
+ [
+ User(name="A"),
+ User(name="C"),
+ User(name="G"),
+ User(name="Q"),
+ ],
+ )
@testing.requires.predictable_gc
def test_cls_not_strong_ref(self):
u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")])
- sess = Session(testing.db)
- sess.add(u1)
- sess.commit()
+ with fixture_session() as sess:
+ sess.add(u1)
+ sess.commit()
- eq_(
- sess.query(User).all(),
- [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
- )
+ eq_(
+ sess.query(User).all(),
+ [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
+ )
def test_string_resolution(self):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
Foo = Base.registry._class_registry["Foo"]
Bar = Base.registry._class_registry["Bar"]
- s = Session(testing.db)
-
- s.add_all(
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ]
- )
- s.commit()
-
- eq_(
- s.query(Foo).order_by(Foo.id).all(),
- [
- Bar(data="d1", bar_data="b1"),
- Bar(data="d2", bar_data="b2"),
- Bar(data="d3", bar_data="b3"),
- Foo(data="d4"),
- ],
- )
+ with fixture_session() as s:
+ s.add_all(
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ]
+ )
+ s.commit()
+
+ eq_(
+ s.query(Foo).order_by(Foo.id).all(),
+ [
+ Bar(data="d1", bar_data="b1"),
+ Bar(data="d2", bar_data="b2"),
+ Bar(data="d3", bar_data="b3"),
+ Foo(data="d4"),
+ ],
+ )
class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
Column("name", String(50)),
)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture(self, collection_class, is_dict=False):
class _CollectionOperations(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
collection_class = self.collection_class
metadata = MetaData()
self.session = fixture_session()
self.Parent, self.Child = Parent, Child
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
class ProxyFactoryTest(ListTest):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
class LazyLoadTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
parents_table = Table(
self.Parent, self.Child = Parent, Child
self.table = parents_table
- def teardown(self):
+ def teardown_test(self):
self.metadata.drop_all(testing.db)
def roundtrip(self, obj):
class DictOfTupleUpdateTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
class B(object):
def __init__(self, key, elem):
self.key = key
class AttributeAccessTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_resolve_aliased_class(self):
run_inserts = "once"
run_deletes = None
- def setup(self):
+ def setup_test(self):
self.bakery = baked.bakery()
__dialect__ = "default"
- def teardown(self):
+ def teardown_test(self):
for cls in (Select, BindParameter):
deregister(cls)
class _ExtBase(object):
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation._reinstall_default_lookups()
class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global MyBaseClass, MyClass
class MyBaseClass(object):
else:
del self._goofy_dict[key]
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_instance_dict(self):
from sqlalchemy import util
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
from sqlalchemy.orm import deferred
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
schema = None
- def setUp(self):
+ def setup_test(self):
global db1, db2, db3, db4, weather_locations, weather_reports
db1, db2, db3, db4 = self._dbs = self._init_dbs()
@classmethod
def setup_session(cls):
- global create_session
+ global sharded_session
shard_lookup = {
"North America": "north_america",
"Asia": "asia",
else:
return ids
- create_session = sessionmaker(
+ sharded_session = sessionmaker(
class_=ShardedSession, autoflush=True, autocommit=False
)
- create_session.configure(
+ sharded_session.configure(
shards={
"north_america": db1,
"asia": db2,
tokyo.reports.append(Report(80.0, id_=1))
newyork.reports.append(Report(75, id_=1))
quito.reports.append(Report(85))
- sess = create_session(future=True)
+ sess = sharded_session(future=True)
for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
sess.add(c)
sess.flush()
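
The rename avoids shadowing ``orm.create_session``; below is a minimal standalone sketch of the factory it refers to, with illustrative engines and chooser callables (the test's real choosers are elided here).

    from sqlalchemy import create_engine
    from sqlalchemy.ext.horizontal_shard import ShardedSession
    from sqlalchemy.orm import sessionmaker

    db1 = create_engine("sqlite://")
    db2 = create_engine("sqlite://")

    sharded_session = sessionmaker(class_=ShardedSession, autoflush=True)
    sharded_session.configure(
        shard_chooser=lambda mapper, instance, clause=None: "north_america",
        id_chooser=lambda query, ident: ["north_america", "asia"],
        query_chooser=lambda query: ["north_america", "asia"],
        shards={"north_america": db1, "asia": db2},
    )

    sess = sharded_session()
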
self.dbs = [db1, db2, db3, db4]
return self.dbs
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- for db in self.dbs:
- db.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
self.engine = e
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.engine.connect().invalidate()
+ testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
self.postgresql_engine = e2
return db1, db2, db3, db4
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
- self.sqlite_engine.connect().invalidate()
+ # the tests in this suite don't cleanly close out the Session
+ # at the moment so use the reaper to close all connections
+ testing_reaper.checkin_all()
+
for i in [1, 3]:
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
self.tables_test_metadata.drop_all(conn)
for i in [2, 4]:
conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
+ self.postgresql_engine.dispose()
class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
return self.dbs
- def teardown(self):
+ def teardown_test(self):
for db in self.dbs:
db.connect().invalidate()
- testing_reaper.close_all()
+ testing_reaper.checkin_all()
for i in range(1, 3):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy import literal
symbols = ("usd", "gbp", "cad", "eur", "aud")
def _type_fixture(cls):
return MutableDict
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableDictTestFixture, self).teardown()
class _MutableDictTestBase(_MutableDictTestFixture):
def _type_fixture(cls):
return MutableList
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableListTestFixture, self).teardown()
class _MutableListTestBase(_MutableListTestFixture):
def _type_fixture(cls):
return MutableSet
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_MutableSetTestFixture, self).teardown()
class _MutableSetTestBase(_MutableSetTestFixture):
Column("unrelated_data", String(50)),
)
- def setup(self):
+ def setup_test(self):
from sqlalchemy.ext import mutable
mutable._setup_composite_listener()
- super(_CompositeTestBase, self).setup()
- def teardown(self):
+ def teardown_test(self):
# clear out mapper events
Mapper.dispatch._clear()
ClassManager.dispatch._clear()
- super(_CompositeTestBase, self).teardown()
@classmethod
def _type_fixture(cls):
from sqlalchemy.orm import relationship
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import create_session
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.util import picklers
class OrderingListTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
global metadata, slides_table, bullets_table, Slide, Bullet
slides_table, bullets_table = None, None
Slide, Bullet = None, None
metadata.create_all(testing.db)
- def teardown(self):
+ def teardown_test(self):
metadata.drop_all(testing.db)
def test_append_no_reorder(self):
self.assert_(s1.bullets[2].position == 3)
self.assert_(s1.bullets[3].position == 4)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
s1.bullets._reorder()
self.assert_(s1.bullets[4].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
self.assert_(len(s1.bullets) == 6)
self.assert_(s1.bullets[5].position == 5)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
self.assert_(s1.bullets[li].position == li)
self.assert_(s1.bullets[li] == b[bi])
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
self.assert_(len(s1.bullets) == 3)
self.assert_(s1.bullets[2].position == 2)
- session = create_session()
+ session = fixture_session()
session.add(s1)
session.flush()
):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
global Base
Base = declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
class ConcurrentUseDeclMappingTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@classmethod
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base
Base = decl.declarative_base(testing.db)
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
Base.metadata.drop_all(testing.db)
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
- def setup(self):
+ def setup_test(self):
global Base, mapper_registry
mapper_registry = registry(metadata=MetaData())
Base = mapper_registry.generate_base()
- def teardown(self):
+ def teardown_test(self):
close_all_sessions()
clear_mappers()
with testing.db.begin() as conn:
class DeclarativeReflectionBase(fixtures.TablesTest):
__requires__ = ("reflectable_autoincrement",)
- def setup(self):
+ def setup_test(self):
global Base, registry
registry = decl.registry(metadata=MetaData())
Base = registry.generate_base()
- def teardown(self):
- super(DeclarativeReflectionBase, self).teardown()
+ def teardown_test(self):
clear_mappers()
from sqlalchemy.orm.util import instance_str
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
@testing.emits_warning(r".*updated rowcount")
@testing.requires.sane_rowcount_w_returning
- @engines.close_open_connections
def test_save_update(self):
subtable, base, stuff = (
self.tables.subtable,
)
return parent, child
- def tearDown(self):
+ def teardown_test(self):
clear_mappers()
def test_warning_on_sub(self):
@classmethod
def insert_data(cls, connection):
Parent, A, B = cls.classes("Parent", "A", "B")
- s = fixture_session()
-
- p1 = Parent(id=1)
- p2 = Parent(id=2)
- s.add_all([p1, p2])
- s.flush()
+ with Session(connection) as s:
+ p1 = Parent(id=1)
+ p2 = Parent(id=2)
+ s.add_all([p1, p2])
+ s.flush()
- s.add_all(
- [
- A(id=1, parent_id=1),
- B(id=2, parent_id=1),
- A(id=3, parent_id=1),
- B(id=4, parent_id=1),
- ]
- )
- s.flush()
+ s.add_all(
+ [
+ A(id=1, parent_id=1),
+ B(id=2, parent_id=1),
+ A(id=3, parent_id=1),
+ B(id=4, parent_id=1),
+ ]
+ )
+ s.flush()
- s.query(A).filter(A.id.in_([3, 4])).update(
- {A.type: None}, synchronize_session=False
- )
- s.commit()
+ s.query(A).filter(A.id.in_([3, 4])).update(
+ {A.type: None}, synchronize_session=False
+ )
+ s.commit()
def test_pk_is_null(self):
Parent, A = self.classes("Parent", "A")
ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes(
"ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB"
)
- s = fixture_session()
+ with Session(connection) as s:
- s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()])
- s.commit()
+ s.add_all(
+ [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]
+ )
+ s.commit()
def test_single_invalid_ident(self):
ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA")
class AttributesTest(fixtures.ORMTest):
- def setup(self):
+ def setup_test(self):
global MyTest, MyTest2
class MyTest(object):
class MyTest2(object):
pass
- def teardown(self):
+ def teardown_test(self):
global MyTest, MyTest2
MyTest, MyTest2 = None, None
class CollectionInitTest(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
class TestUnlink(fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
class A(object):
pass
User, users = self.classes.User, self.tables.users
mapper(User, users)
- c = testing.db.connect()
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.close()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- sess = Session(bind=c, autocommit=False)
- u = User(name="u2")
- sess.add(u)
- sess.flush()
- sess.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
-
- with c.begin():
- c.exec_driver_sql("delete from users")
- assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
- c = testing.db.connect()
-
- trans = c.begin()
- sess = Session(bind=c, autocommit=True)
- u = User(name="u3")
- sess.add(u)
- sess.flush()
- assert c.in_transaction()
- trans.commit()
- assert not c.in_transaction()
- assert c.exec_driver_sql("select count(1) from users").scalar() == 1
+ with testing.db.connect() as c:
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.close()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ sess = Session(bind=c, autocommit=False)
+ u = User(name="u2")
+ sess.add(u)
+ sess.flush()
+ sess.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
+
+ with c.begin():
+ c.exec_driver_sql("delete from users")
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 0
+ )
+
+ with testing.db.connect() as c:
+ trans = c.begin()
+ sess = Session(bind=c, autocommit=True)
+ u = User(name="u3")
+ sess.add(u)
+ sess.flush()
+ assert c.in_transaction()
+ trans.commit()
+ assert not c.in_transaction()
+ assert (
+ c.exec_driver_sql("select count(1) from users").scalar() == 1
+ )
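
A condensed sketch of what the rewritten test relies on: a :class:`.Session` bound to an existing :class:`.Connection` runs its transaction on that connection, and ``commit()`` leaves the connection with no transaction in progress.

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")

    with engine.connect() as conn:
        sess = Session(bind=conn)
        sess.execute(text("select 1"))
        sess.commit()
        assert not conn.in_transaction()
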
class SessionBindTest(fixtures.MappedTest):
finally:
if hasattr(bind, "close"):
bind.close()
+ sess.close()
def test_session_unbound(self):
Foo = self.classes.Foo
return str((id(self), self.a, self.b, self.c))
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
instrumentation.register_class(cls.Entity)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
instrumentation.unregister_class(cls.Entity)
- super(CollectionsTest, cls).teardown_class()
_entity_id = 1
class CompileTest(fixtures.ORMTest):
"""test various mapper compilation scenarios"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def test_with_polymorphic(self):
id = Column(Integer, primary_key=True)
a_id = Column(ForeignKey("a.id", name="a_fk"))
- def setup(self):
- super(PostUpdateOnUpdateTest, self).setup()
+ def setup_test(self):
PostUpdateOnUpdateTest.counter = count()
PostUpdateOnUpdateTest.db_counter = count()
class AutocommitClosesOnFailTest(fixtures.MappedTest):
__requires__ = ("deferrable_fks",)
+ __only_on__ = ("postgresql+psycopg2",) # needs #5824 for asyncpg
@classmethod
def define_tables(cls, metadata):
warnings += (join_aliased_dep,)
# load a user who has an order that contains item id 3 and address
# id 1 (order 3, owned by jack)
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .join("orders", "items", aliased=aliased_, isouter=True)
- .filter_by(id=3)
- .reset_joinpoint()
- .join("orders", "address", aliased=aliased_, isouter=True)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
-
- with testing.expect_deprecated_20(*warnings):
- result = (
- fixture_session()
- .query(User)
- .outerjoin("orders", "items", aliased=aliased_)
- .filter_by(id=3)
- .reset_joinpoint()
- .outerjoin("orders", "address", aliased=aliased_)
- .filter_by(id=1)
- .all()
- )
- assert [User(id=7, name="jack")] == result
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .join(
+ "orders", "items", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .join(
+ "orders", "address", aliased=aliased_, isouter=True
+ )
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
+
+ with fixture_session() as sess:
+ with testing.expect_deprecated_20(*warnings):
+ result = (
+ sess.query(User)
+ .outerjoin("orders", "items", aliased=aliased_)
+ .filter_by(id=3)
+ .reset_joinpoint()
+ .outerjoin("orders", "address", aliased=aliased_)
+ .filter_by(id=1)
+ .all()
+ )
+ assert [User(id=7, name="jack")] == result
class AliasFromCorrectLeftTest(
5,
),
]:
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opt).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opt).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
def test_disable_dynamic(self):
"""test no joined option on a dynamic."""
class _RemoveListeners(object):
- def teardown(self):
+ @testing.fixture(autouse=True)
+ def _remove_listeners(self):
+ yield
events.MapperEvents._clear()
events.InstanceEvents._clear()
events.SessionEvents._clear()
events.InstrumentationEvents._clear()
events.QueryEvents._clear()
- super(_RemoveListeners, self).teardown()
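
The replacement relies on the standard autouse-fixture idiom, shown here with plain pytest for reference: everything before the ``yield`` is setup, everything after it runs as per-test teardown.

    import pytest

    class TestWithCleanup:
        registered = []

        @pytest.fixture(autouse=True)
        def _remove_listeners(self):
            yield  # the test body runs at this point
            self.registered.clear()  # teardown after every test
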
class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
argnames="target, event_name, fn",
)(fn)
- def teardown(self):
+ def teardown_test(self):
A = self.classes.A
A._sa_class_manager.dispatch._clear()
]
adalias = addresses.alias()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
- assert q.all() == expected
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+
+ eq_(q.all(), expected)
# test with a straight statement
s = (
.group_by(*[c for c in users.c])
.order_by(users.c.id)
)
- q = fixture_session().query(User)
- result = (
- q.add_columns(s.selected_columns.count, s.selected_columns.concat)
- .from_statement(s)
- .all()
- )
- assert result == expected
-
- sess.expunge_all()
- # test with select_entity_from()
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .select_entity_from(users.outerjoin(addresses))
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = sess.query(User)
+ result = (
+ q.add_columns(
+ s.selected_columns.count, s.selected_columns.concat
+ )
+ .from_statement(s)
+ .all()
+ )
+ eq_(result, expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ # test with select_entity_from()
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .select_entity_from(users.outerjoin(addresses))
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
- .outerjoin("addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ eq_(q.all(), expected)
- assert q.all() == expected
- sess.expunge_all()
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(addresses.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin("addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
+ eq_(q.all(), expected)
- q = (
- fixture_session()
- .query(User)
- .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
- .outerjoin(adalias, "addresses")
- .group_by(users)
- .order_by(users.c.id)
- )
+ with fixture_session() as sess:
+ q = (
+ sess.query(User)
+ .add_columns(
+ func.count(adalias.c.id), ("Name:" + users.c.name)
+ )
+ .outerjoin(adalias, "addresses")
+ .group_by(users)
+ .order_by(users.c.id)
+ )
- assert q.all() == expected
- sess.expunge_all()
+ eq_(q.all(), expected)
def test_expression_selectable_matches_mzero(self):
User, Address = self.classes.User, self.classes.Address
),
)
- sess = fixture_session()
+ with fixture_session() as sess:
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_uses_get_compatible_types(self):
"""test the use_get optimization with compatible
properties=dict(user=relationship(mapper(User, users))),
)
- sess = fixture_session()
-
- # load address
- a1 = (
- sess.query(Address)
- .filter_by(email_address="ed@wood.com")
- .one()
- )
+ with fixture_session() as sess:
+ # load address
+ a1 = (
+ sess.query(Address)
+ .filter_by(email_address="ed@wood.com")
+ .one()
+ )
- # load user that is attached to the address
- u1 = sess.query(User).get(8)
+ # load user that is attached to the address
+ u1 = sess.query(User).get(8)
- def go():
- # lazy load of a1.user should get it from the session
- assert a1.user is u1
+ def go():
+ # lazy load of a1.user should get it from the session
+ assert a1.user is u1
- self.assert_sql_count(testing.db, go, 0)
- sa.orm.clear_mappers()
+ self.assert_sql_count(testing.db, go, 0)
+ sa.orm.clear_mappers()
def test_many_to_one(self):
users, Address, addresses, User = (
from sqlalchemy.orm.attributes import instance_state
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
-engine = testing.db
-
-
class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
)
parent_id = Column(Integer, ForeignKey("parent.id"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
- def tearDown(self):
- Base.metadata.drop_all(engine)
+ def teardown_test(self):
+ Base.metadata.drop_all(testing.db)
def test_annoying_autoflush_one(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
p1.children = []
def test_annoying_autoflush_two(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
assert p1.children == []
def test_dont_load_if_no_keys(self):
- sess = Session(engine)
+ sess = fixture_session()
p1 = Parent()
sess.add(p1)
class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
- def setUp(self):
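+ # this suite holds a global session/connection open for each test;
+ # the flag below (assumed semantics) tells the harness not to
+ # reclaim connections during teardown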
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
global Parent, Child, Base
Base = declarative_base()
parent = relationship(Parent, backref=backref("children"))
- Base.metadata.create_all(engine)
+ Base.metadata.create_all(testing.db)
global sess, p1, p2, c1, c2
- sess = Session(bind=engine)
+ sess = Session(bind=testing.db)
p1 = Parent()
p2 = Parent()
sess.commit()
- def tearDown(self):
+ def teardown_test(self):
sess.rollback()
- Base.metadata.drop_all(engine)
+ Base.metadata.drop_all(testing.db)
def test_load_on_pending_allows_backref_event(self):
Child.parent.property.load_on_pending = True
class DocumentTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
self.mapper = registry().map_imperatively
class ORMLoggingTest(_fixtures.FixtureTest):
- def setup(self):
+ def setup_test(self):
self.buf = logging.handlers.BufferingHandler(100)
for log in [logging.getLogger("sqlalchemy.orm")]:
log.addHandler(self.buf)
self.mapper = registry().map_imperatively
- def teardown(self):
+ def teardown_test(self):
for log in [logging.getLogger("sqlalchemy.orm")]:
log.removeHandler(self.buf)
class LocalOptsTest(PathTest, QueryTest):
@classmethod
- def setup_class(cls):
- super(LocalOptsTest, cls).setup_class()
-
+ def setup_test_class(cls):
@strategy_options.loader_option()
def some_col_opt_only(loadopt, key, opts):
return loadopt.set_column_strategy(
[User.orders_syn, Order.items_syn],
[User.orders_syn_2, Order.items_syn],
):
- q = fixture_session().query(User)
- for path in j:
- q = q.join(path)
- q = q.filter_by(id=3)
- result = q.all()
- assert [User(id=7, name="jack"), User(id=9, name="fred")] == result
+ with fixture_session() as sess:
+ q = sess.query(User)
+ for path in j:
+ q = q.join(path)
+ q = q.filter_by(id=3)
+ result = q.all()
+ eq_(
+ result,
+ [
+ User(id=7, name="jack"),
+ User(id=9, name="fred"),
+ ],
+ )
def test_with_parent(self):
Order, User = self.classes.Order, self.classes.User
("name_syn", "orders_syn"),
("name_syn", "orders_syn_2"),
):
- sess = fixture_session()
- q = sess.query(User)
+ with fixture_session() as sess:
+ q = sess.query(User)
- u1 = q.filter_by(**{nameprop: "jack"}).one()
+ u1 = q.filter_by(**{nameprop: "jack"}).one()
- o = sess.query(Order).with_parent(u1, property=orderprop).all()
- assert [
- Order(description="order 1"),
- Order(description="order 3"),
- Order(description="order 5"),
- ] == o
+ o = sess.query(Order).with_parent(u1, property=orderprop).all()
+ assert [
+ Order(description="order 1"),
+ Order(description="order 3"),
+ Order(description="order 5"),
+ ] == o
def test_froms_aliased_col(self):
Address, User = self.classes.Address, self.classes.User
class _JoinFixtures(object):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
m = MetaData()
cls.left = Table(
"lft",
"""
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
def _fixture_one(
assert_raises(sa.exc.ArgumentError, configure_mappers)
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
class SecondaryArgTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
clear_mappers()
@testing.combinations((True,), (False,))
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
run_inserts = None
- def setup(self):
+ def setup_test(self):
mapper(self.classes.User, self.tables.users)
def _assert_modified(self, u1):
def _assert_no_cycle(self, u1):
assert sa.orm.attributes.instance_state(u1)._strong_obj is None
- def _persistent_fixture(self):
+ def _persistent_fixture(self, gc_collect=False):
User = self.classes.User
u1 = User()
u1.name = "ed"
- sess = fixture_session()
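+ # gc-based tests use a plain Session: a fixture-managed session
+ # would stay referenced by the test harness and never be garbage
+ # collected (assumed rationale for the gc_collect flag)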
+ if gc_collect:
+ sess = Session(testing.db)
+ else:
+ sess = fixture_session()
sess.add(u1)
sess.flush()
return sess, u1
@testing.requires.predictable_gc
def test_move_gc_session_persistent_dirty(self):
- sess, u1 = self._persistent_fixture()
+ sess, u1 = self._persistent_fixture(gc_collect=True)
u1.name = "edchanged"
self._assert_cycle(u1)
self._assert_modified(u1)
del sess
gc_collect()
self._assert_cycle(u1)
- s2 = fixture_session()
+ s2 = Session(testing.db)
s2.add(u1)
self._assert_cycle(u1)
self._assert_modified(u1)
mapper(User, users)
- sess = fixture_session()
+ sess = Session(testing.db)
u1 = User(name="u1")
sess.add(u1)
# can't add u1 to Session,
# already belongs to u2
- s2 = fixture_session()
+ s2 = Session(testing.db)
assert_raises_message(
sa.exc.InvalidRequestError,
r".*is already attached to session",
mapper(T, cls.tables.t1)
- def teardown(self):
+ def teardown_test(self):
from sqlalchemy.orm.session import _sessions
_sessions.clear()
- super(DisposedStates, self).teardown()
def _set_imap_in_disposal(self, sess, *objs):
"""remove selected objects from the given session, as though
def _do_query_tests(self, opts, count):
Order, User = self.classes.Order, self.classes.User
- sess = fixture_session()
+ with fixture_session() as sess:
- def go():
- eq_(
- sess.query(User).options(*opts).order_by(User.id).all(),
- self.static.user_item_keyword_result,
- )
+ def go():
+ eq_(
+ sess.query(User).options(*opts).order_by(User.id).all(),
+ self.static.user_item_keyword_result,
+ )
- self.assert_sql_count(testing.db, go, count)
+ self.assert_sql_count(testing.db, go, count)
- eq_(
- sess.query(User)
- .options(*opts)
- .filter(User.name == "fred")
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[2:3],
- )
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .filter(User.name == "fred")
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[2:3],
+ )
- sess = fixture_session()
- eq_(
- sess.query(User)
- .options(*opts)
- .join(User.orders)
- .filter(Order.id == 3)
- .order_by(User.id)
- .all(),
- self.static.user_item_keyword_result[0:1],
- )
+ with fixture_session() as sess:
+ eq_(
+ sess.query(User)
+ .options(*opts)
+ .join(User.orders)
+ .filter(Order.id == 3)
+ .order_by(User.id)
+ .all(),
+ self.static.user_item_keyword_result[0:1],
+ )
def test_cyclical(self):
"""A circular eager relationship breaks the cycle with a lazy loader"""
users, User = self.tables.users, self.classes.User
mapper(User, users)
- conn = testing.db.connect()
- trans = conn.begin()
- sess = Session(bind=conn, autocommit=False, autoflush=True)
- sess.begin(subtransactions=True)
- u = User(name="ed")
- sess.add(u)
- sess.flush()
- sess.commit() # commit does nothing
- trans.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+
+ with testing.db.connect() as conn:
+ trans = conn.begin()
+ sess = Session(bind=conn, autocommit=False, autoflush=True)
+ sess.begin(subtransactions=True)
+ u = User(name="ed")
+ sess.add(u)
+ sess.flush()
+ sess.commit() # commit does nothing
+ trans.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@engines.close_open_connections
def test_subtransaction_on_external_no_begin(self):
users = self.tables.users
engine = Engine._future_facade(testing.db)
- session = Session(engine, autocommit=False)
-
- session.begin()
- session.connection().execute(users.insert().values(name="user1"))
- session.begin(subtransactions=True)
- session.begin_nested()
- session.connection().execute(users.insert().values(name="user2"))
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
- session.rollback()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(users.insert().values(name="user3"))
- session.commit()
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
+ with Session(engine, autocommit=False) as session:
+ session.begin()
+ session.connection().execute(users.insert().values(name="user1"))
+ session.begin(subtransactions=True)
+ session.begin_nested()
+ session.connection().execute(users.insert().values(name="user2"))
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ session.rollback()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(users.insert().values(name="user3"))
+ session.commit()
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting(self):
mapper(User, users)
- s = Session(testing.db)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session() as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- s.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ s.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- s.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ s.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.savepoints
def test_dirty_state_transferred_deep_nesting_future(self):
mapper(User, users)
- s = Session(testing.db, future=True)
- u1 = User(name="u1")
- s.add(u1)
- s.commit()
-
- nt1 = s.begin_nested()
- nt2 = s.begin_nested()
- u1.name = "u2"
- assert attributes.instance_state(u1) not in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
- s.flush()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) not in nt1._dirty
+ with fixture_session(future=True) as s:
+ u1 = User(name="u1")
+ s.add(u1)
+ s.commit()
+
+ nt1 = s.begin_nested()
+ nt2 = s.begin_nested()
+ u1.name = "u2"
+ assert attributes.instance_state(u1) not in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
+ s.flush()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) not in nt1._dirty
- nt2.commit()
- assert attributes.instance_state(u1) in nt2._dirty
- assert attributes.instance_state(u1) in nt1._dirty
+ nt2.commit()
+ assert attributes.instance_state(u1) in nt2._dirty
+ assert attributes.instance_state(u1) in nt1._dirty
- nt1.rollback()
- assert attributes.instance_state(u1).expired
- eq_(u1.name, "u1")
+ nt1.rollback()
+ assert attributes.instance_state(u1).expired
+ eq_(u1.name, "u1")
@testing.requires.independent_connections
def test_transactions_isolated(self):
mapper(User, users)
- session = Session(testing.db)
+ with fixture_session() as session:
- with expect_warnings(".*during handling of a previous exception.*"):
- session.begin_nested()
- savepoint = session.connection()._nested_transaction._savepoint
+ with expect_warnings(
+ ".*during handling of a previous exception.*"
+ ):
+ session.begin_nested()
+ savepoint = session.connection()._nested_transaction._savepoint
- # force the savepoint to disappear
- session.connection().dialect.do_release_savepoint(
- session.connection(), savepoint
- )
+ # force the savepoint to disappear
+ session.connection().dialect.do_release_savepoint(
+ session.connection(), savepoint
+ )
- # now do a broken flush
- session.add_all([User(id=1), User(id=1)])
+ # now do a broken flush
+ session.add_all([User(id=1), User(id=1)])
- assert_raises_message(
- sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
- )
+ assert_raises_message(
+ sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
+ )
class _LocalFixture(FixtureTest):
def test_recipe_heavy_nesting(self, subtransaction_recipe):
users = self.tables.users
- session = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(session):
- session.connection().execute(users.insert().values(name="user1"))
+ with fixture_session(future=self.future) as session:
with subtransaction_recipe(session):
- savepoint = session.begin_nested()
session.connection().execute(
- users.insert().values(name="user2")
+ users.insert().values(name="user1")
)
+ with subtransaction_recipe(session):
+ savepoint = session.begin_nested()
+ session.connection().execute(
+ users.insert().values(name="user2")
+ )
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 2
+ )
+ savepoint.rollback()
+
+ with subtransaction_recipe(session):
+ assert (
+ session.connection()
+ .exec_driver_sql("select count(1) from users")
+ .scalar()
+ == 1
+ )
+ session.connection().execute(
+ users.insert().values(name="user3")
+ )
assert (
session.connection()
.exec_driver_sql("select count(1) from users")
.scalar()
== 2
)
- savepoint.rollback()
-
- with subtransaction_recipe(session):
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 1
- )
- session.connection().execute(
- users.insert().values(name="user3")
- )
- assert (
- session.connection()
- .exec_driver_sql("select count(1) from users")
- .scalar()
- == 2
- )
@engines.close_open_connections
def test_recipe_subtransaction_on_external_subtrans(
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.close()
- assert len(sess.query(User).all()) == 1
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.close()
+ assert len(sess.query(User).all()) == 1
def test_recipe_subtransaction_on_noautocommit(
self, subtransaction_recipe
User, users = self.classes.User, self.tables.users
mapper(User, users)
- sess = Session(testing.db, future=self.future)
-
- sess.begin()
- with subtransaction_recipe(sess):
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- sess.rollback() # rolls back
- assert len(sess.query(User).all()) == 0
- sess.close()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
+ with subtransaction_recipe(sess):
+ u = User(name="u1")
+ sess.add(u)
+ sess.flush()
+ sess.rollback() # rolls back
+ assert len(sess.query(User).all()) == 0
+ sess.close()
@testing.requires.savepoints
def test_recipe_mixed_transaction_control(self, subtransaction_recipe):
mapper(User, users)
- sess = Session(testing.db, future=self.future)
+ with fixture_session(future=self.future) as sess:
- sess.begin()
- sess.begin_nested()
+ sess.begin()
+ sess.begin_nested()
- with subtransaction_recipe(sess):
+ with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
+ sess.add(User(name="u1"))
- sess.commit()
- sess.commit()
+ sess.commit()
+ sess.commit()
- eq_(len(sess.query(User).all()), 1)
- sess.close()
+ eq_(len(sess.query(User).all()), 1)
+ sess.close()
- t1 = sess.begin()
- t2 = sess.begin_nested()
-
- sess.add(User(name="u2"))
+ t1 = sess.begin()
+ t2 = sess.begin_nested()
- t2.commit()
- assert sess._legacy_transaction() is t1
+ sess.add(User(name="u2"))
- sess.close()
+ t2.commit()
+ assert sess._legacy_transaction() is t1
def test_recipe_error_on_using_inactive_session_commands(
self, subtransaction_recipe
users, User = self.tables.users, self.classes.User
mapper(User, users)
- sess = Session(testing.db, future=self.future)
- sess.begin()
-
- try:
- with subtransaction_recipe(sess):
- sess.add(User(name="u1"))
- sess.flush()
- raise Exception("force rollback")
- except:
- pass
-
- if self.recipe_rollsback_early:
- # that was a real rollback, so no transaction
- assert not sess.in_transaction()
- is_(sess.get_transaction(), None)
- else:
- assert sess.in_transaction()
-
- sess.close()
- assert not sess.in_transaction()
-
- def test_recipe_multi_nesting(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
-
- with subtransaction_recipe(sess):
- assert sess.in_transaction()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
try:
with subtransaction_recipe(sess):
- assert sess._legacy_transaction()
+ sess.add(User(name="u1"))
+ sess.flush()
raise Exception("force rollback")
except:
pass
if self.recipe_rollsback_early:
+ # that was a real rollback, so no transaction
assert not sess.in_transaction()
+ is_(sess.get_transaction(), None)
else:
assert sess.in_transaction()
- assert not sess.in_transaction()
+ sess.close()
+ assert not sess.in_transaction()
+
+ def test_recipe_multi_nesting(self, subtransaction_recipe):
+ with fixture_session(future=self.future) as sess:
+ with subtransaction_recipe(sess):
+ assert sess.in_transaction()
+
+ try:
+ with subtransaction_recipe(sess):
+ assert sess._legacy_transaction()
+ raise Exception("force rollback")
+ except:
+ pass
+
+ if self.recipe_rollsback_early:
+ assert not sess.in_transaction()
+ else:
+ assert sess.in_transaction()
+
+ assert not sess.in_transaction()
def test_recipe_deactive_status_check(self, subtransaction_recipe):
- sess = Session(testing.db, future=self.future)
- sess.begin()
+ with fixture_session(future=self.future) as sess:
+ sess.begin()
- with subtransaction_recipe(sess):
- sess.rollback()
+ with subtransaction_recipe(sess):
+ sess.rollback()
- assert not sess.in_transaction()
- sess.commit() # no error
+ assert not sess.in_transaction()
+ sess.commit() # no error
class FixtureDataTest(_LocalFixture):
mapper(User, users)
- s = Session(bind=testing.db, future=future)
- u1 = User(name="u1")
- u2 = User(name="u2")
- s.add_all([u1, u2])
- s.commit()
- u1.name
- u2.name
- trans = s._transaction
- assert trans is not None
- s.begin_nested()
- update_fn(s, u2)
- eq_(u2.name, "u2modified")
- s.rollback()
+ with fixture_session(future=future) as s:
+ u1 = User(name="u1")
+ u2 = User(name="u2")
+ s.add_all([u1, u2])
+ s.commit()
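+ # attribute access after commit loads the expired values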
+ u1.name
+ u2.name
+ trans = s._transaction
+ assert trans is not None
+ s.begin_nested()
+ update_fn(s, u2)
+ eq_(u2.name, "u2modified")
+ s.rollback()
- if future:
- assert s._transaction is None
- assert "name" not in u1.__dict__
- else:
- assert s._transaction is trans
- eq_(u1.__dict__["name"], "u1")
- assert "name" not in u2.__dict__
- eq_(u2.name, "u2")
+ if future:
+ assert s._transaction is None
+ assert "name" not in u1.__dict__
+ else:
+ assert s._transaction is trans
+ eq_(u1.__dict__["name"], "u1")
+ assert "name" not in u2.__dict__
+ eq_(u2.name, "u2")
@testing.requires.savepoints
def test_rollback_ignores_clean_on_savepoint(self):
eq_(sess.query(User).count(), 1)
def test_explicit_begin(self):
- s1 = Session(testing.db)
- with s1.begin() as trans:
- is_(trans, s1._legacy_transaction())
- s1.connection()
+ with fixture_session() as s1:
+ with s1.begin() as trans:
+ is_(trans, s1._legacy_transaction())
+ s1.connection()
- is_(s1._transaction, None)
+ is_(s1._transaction, None)
def test_no_double_begin_explicit(self):
- s1 = Session(testing.db)
- s1.begin()
- assert_raises_message(
- sa_exc.InvalidRequestError,
- "A transaction is already begun on this Session.",
- s1.begin,
- )
+ with fixture_session() as s1:
+ s1.begin()
+ assert_raises_message(
+ sa_exc.InvalidRequestError,
+ "A transaction is already begun on this Session.",
+ s1.begin,
+ )
@testing.requires.savepoints
def test_future_rollback_is_global(self):
users = self.tables.users
- s1 = Session(testing.db, future=True)
+ with fixture_session(future=True) as s1:
+ s1.begin()
- s1.begin()
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.begin_nested()
- s1.begin_nested()
-
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back the whole transaction
- s1.rollback()
- is_(s1._legacy_transaction(), None)
+ # rolls back the whole transaction
+ s1.rollback()
+ is_(s1._legacy_transaction(), None)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 0,
+ )
- s1.commit()
- is_(s1._legacy_transaction(), None)
+ s1.commit()
+ is_(s1._legacy_transaction(), None)
@testing.requires.savepoints
def test_old_rollback_is_local(self):
users = self.tables.users
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- t1 = s1.begin()
+ t1 = s1.begin()
- s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+ s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
- s1.begin_nested()
+ s1.begin_nested()
- s1.connection().execute(
- users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
- )
+ s1.connection().execute(
+ users.insert(),
+ [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+ )
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 3,
+ )
- # rolls back only the savepoint
- s1.rollback()
+ # rolls back only the savepoint
+ s1.rollback()
- is_(s1._legacy_transaction(), t1)
+ is_(s1._legacy_transaction(), t1)
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
- s1.commit()
- eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
- is_not(s1._legacy_transaction(), None)
+ s1.commit()
+ eq_(
+ s1.connection().scalar(
+ select(func.count()).select_from(users)
+ ),
+ 1,
+ )
+ is_not(s1._legacy_transaction(), None)
def test_session_as_ctx_manager_one(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
def test_session_as_ctx_manager_future_one(self):
users = self.tables.users
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
is_not(sess._legacy_transaction(), None)
sess.connection().execute(
users = self.tables.users
try:
- with Session(testing.db, future=True) as sess:
+ with fixture_session(future=True) as sess:
is_(sess._legacy_transaction(), None)
sess.connection().execute(
def test_begin_context_manager(self):
users = self.tables.users
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
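+ # the eq_ above began a new transaction via sess.connection();
+ # close explicitly so that connection is released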
+ sess.close()
def test_begin_context_manager_rollback_trans(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_begin_context_manager_rollback_outer(self):
users = self.tables.users
try:
- with Session(testing.db) as sess:
+ with fixture_session() as sess:
with sess.begin():
sess.connection().execute(
users.insert().values(id=1, name="user1")
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_trans(self):
users = self.tables.users
# rolled back
eq_(sess.connection().execute(users.select()).all(), [])
+ sess.close()
def test_sessionmaker_begin_context_manager_rollback_outer(self):
users = self.tables.users
# committed
eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+ sess.close()
class TransactionFlagsTest(fixtures.TestBase):
def test_in_transaction(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- n1 = s1.begin_nested()
+ n1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), n1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), n1)
- n1.rollback()
+ n1.rollback()
- is_(s1.get_nested_transaction(), None)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ eq_(s1.in_transaction(), True)
- s1.commit()
+ s1.commit()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
def test_in_transaction_subtransactions(self):
"""we'd like to do away with subtransactions for future sessions
the external API works.
"""
- s1 = Session(testing.db)
-
- eq_(s1.in_transaction(), False)
+ with fixture_session() as s1:
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- subtrans = s1.begin(_subtrans=True)
- is_(s1.get_transaction(), trans)
- eq_(s1.in_transaction(), True)
+ subtrans = s1.begin(_subtrans=True)
+ is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
- is_(s1._transaction, subtrans)
+ is_(s1._transaction, subtrans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), True)
- is_(s1._transaction, trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1._transaction, trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1._transaction, None)
+ eq_(s1.in_transaction(), False)
+ is_(s1._transaction, None)
def test_in_transaction_nesting(self):
- s1 = Session(testing.db)
+ with fixture_session() as s1:
- eq_(s1.in_transaction(), False)
+ eq_(s1.in_transaction(), False)
- trans = s1.begin()
+ trans = s1.begin()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
- sp1 = s1.begin_nested()
+ sp1 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp2 = s1.begin_nested()
+ sp2 = s1.begin_nested()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp2)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp2)
- sp2.rollback()
+ sp2.rollback()
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), True)
- is_(s1.get_transaction(), trans)
- is_(s1.get_nested_transaction(), sp1)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), True)
+ is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), sp1)
- sp1.rollback()
+ sp1.rollback()
- is_(s1.get_nested_transaction(), None)
- eq_(s1.in_transaction(), True)
- eq_(s1.in_nested_transaction(), False)
- is_(s1.get_transaction(), trans)
+ is_(s1.get_nested_transaction(), None)
+ eq_(s1.in_transaction(), True)
+ eq_(s1.in_nested_transaction(), False)
+ is_(s1.get_transaction(), trans)
- s1.rollback()
+ s1.rollback()
- eq_(s1.in_transaction(), False)
- is_(s1.get_transaction(), None)
+ eq_(s1.in_transaction(), False)
+ is_(s1.get_transaction(), None)
class NaturalPKRollbackTest(fixtures.MappedTest):
class JoinIntoAnExternalTransactionFixture(object):
"""Test the "join into an external transaction" examples"""
- def setup(self):
- self.connection = testing.db.connect()
+ __leave_connections_for_teardown__ = True
+
+ def setup_test(self):
+ self.engine = testing.db
+ self.connection = self.engine.connect()
self.metadata = MetaData()
self.table = Table(
self.setup_session()
+ def teardown_test(self):
+ self.teardown_session()
+
+ with self.connection.begin():
+ self._assert_count(0)
+
+ with self.connection.begin():
+ self.table.drop(self.connection)
+
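+ # return the connection to the Engine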
+ self.connection.close()
+
def test_something(self):
A = self.A
)
eq_(result, count)
- def teardown(self):
- self.teardown_session()
-
- with self.connection.begin():
- self._assert_count(0)
-
- with self.connection.begin():
- self.table.drop(self.connection)
-
- # return connection to the Engine
- self.connection.close()
-
class NewStyleJoinIntoAnExternalTransactionTest(
JoinIntoAnExternalTransactionFixture
# rollback - everything that happened with the
# Session above (including calls to commit())
# is rolled back.
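+ # the transaction may already have been closed during the
+ # steps above, hence the is_active guard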
- self.trans.rollback()
+ if self.trans.is_active:
+ self.trans.rollback()
class FutureJoinIntoAnExternalTransactionTest(
cls.tables["t1"] = t1
cls.tables["t2"] = t2
- @classmethod
- def setup_class(cls):
- super(UnicodeSchemaTest, cls).setup_class()
-
- @classmethod
- def teardown_class(cls):
- super(UnicodeSchemaTest, cls).teardown_class()
-
def test_mapping(self):
t2, t1 = self.tables.t2, self.tables.t1
class SingleCycleTest(UOWTest):
- def teardown(self):
+ def teardown_test(self):
engines.testing_reaper.rollback_all()
# mysql can't handle delete from nodes
# since it doesn't deal with the FKs correctly,
# so wipe out the parent_id first
with testing.db.begin() as conn:
conn.execute(self.tables.nodes.update().values(parent_id=None))
- super(SingleCycleTest, self).teardown()
def test_one_to_many_save(self):
Node, nodes = self.classes.Node, self.tables.nodes
@property
def postgresql_utf8_server_encoding(self):
+ def go(config):
+ if not against(config, "postgresql"):
+ return False
- return only_if(
- lambda config: against(config, "postgresql")
- and config.db.connect(close_with_result=True)
- .exec_driver_sql("show server_encoding")
- .scalar()
- .lower()
- == "utf8"
- )
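+ # use a context-managed connection so it is always returned to
+ # the pool (replaces the close_with_result pattern)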
+ with config.db.connect() as conn:
+ enc = conn.exec_driver_sql("show server_encoding").scalar()
+ return enc.lower() == "utf8"
+
+ return only_if(go)
@property
def cxoracle6_or_greater(self):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
metadata = MetaData()
global info_table
info_table = Table(
)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
with testing.db.begin() as conn:
info_table.drop(conn)
class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
# TODO: we need to get dialects here somehow, perhaps in test_suite?
[
importlib.import_module("sqlalchemy.dialects.%s" % d)
class KwargPropagationTest(fixtures.TestBase):
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
from sqlalchemy.sql.expression import ColumnClause, TableClause
class CatchCol(ColumnClause):
Column("col11", MyType(), default="foo"),
)
- def teardown(self):
+ def teardown_test(self):
self.default_generator["x"] = 50
- super(DefaultRoundTripTest, self).teardown()
def test_standalone(self, connection):
t = self.tables.default_test
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
class MyInteger(TypeDecorator):
impl = Integer
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
class DDLDeprecatedBindTest(fixtures.TestBase):
- def teardown(self):
+ def teardown_test(self):
with testing.db.begin() as conn:
if inspect(conn).has_table("foo"):
conn.execute(schema.DropTable(table("foo")))
ability to copy and modify a ClauseElement in place."""
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global A, B
# establish two fictitious ClauseElements.
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2, t3
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table(
"table1",
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global table1, table2, table3, table4
def _table(name):
__dialect__ = "default"
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
# fixme: consolidate coverage from elsewhere here and expand
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
global t1, t2
t1 = table("table1", column("col1"), column("col2"), column("col3"))
t2 = table("table2", column("col1"), column("col2"), column("col3"))
Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
- def setup(self):
+ def setup_test(self):
self.a = self.tables.table_a
self.b = self.tables.table_b
self.c = self.tables.table_c
with self.bind.connect() as conn:
conn.execute(query)
- def test_no_linting(self):
- eng = engines.testing_engine(options={"enable_from_linting": False})
+ def test_no_linting(self, metadata, connection):
+ eng = engines.testing_engine(
+ options={"enable_from_linting": False, "use_reaper": False}
+ )
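+ # use_reaper=False: this throwaway engine is handed the main
+ # engine's pool below, which must not be reaped (assumed rationale)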
eng.pool = self.bind.pool # needed for SQLite
a, b = self.tables("table_a", "table_b")
query = select(a.c.col_a).where(b.c.col_b == 5)
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_compile(self):
class ExecuteTest(fixtures.TestBase):
__backend__ = True
- def tearDown(self):
+ def teardown_test(self):
pass
def test_conn_execute(self, connection):
class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = "default"
- def setup(self):
+ def setup_test(self):
self._registry = deepcopy(functions._registry)
- def teardown(self):
+ def teardown_test(self):
functions._registry = self._registry
def test_GenericFunction_is_registered(self):
with mock.patch("sqlalchemy.dialects.registry.load", load):
yield
- def teardown(self):
+ def teardown_test(self):
Index._kw_registry.clear()
def test_participating(self):
class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
- def setUp(self):
+ def setup_test(self):
class MyTypeCompiler(compiler.GenericTypeCompiler):
def visit_mytype(self, type_, **kw):
return "MYTYPE"
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default_enhanced"
- def setUp(self):
+ def setup_test(self):
self.table = table(
"mytable", column("myid", Integer), column("name", String)
)
assert_raises(KeyError, lambda: row._mapping["Case_insensitive"])
assert_raises(KeyError, lambda: row._mapping["casesensitive"])
- def test_row_case_sensitive_unoptimized(self):
- with engines.testing_engine().connect() as ins_conn:
+ def test_row_case_sensitive_unoptimized(self, testing_engine):
+ with testing_engine().connect() as ins_conn:
row = ins_conn.execute(
select(
literal_column("1").label("case_insensitive"),
eq_(proxy[0], "value")
eq_(proxy._mapping["key"], "value")
- @testing.provide_metadata
- def test_no_rowcount_on_selects_inserts(self):
+ def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine):
"""assert that rowcount is only called on deletes and updates.
This is because cursor.rowcount may be expensive on some dialects
"""
- metadata = self.metadata
-
- engine = engines.testing_engine()
+ engine = testing_engine()
t = Table("t1", metadata, Column("data", String(10)))
metadata.create_all(engine)
@classmethod
def setup_bind(cls):
- cls.engine = engine = engines.testing_engine("sqlite://")
+ cls.engine = engine = engines.testing_engine(
+ "sqlite://", options={"scope": "class"}
+ )
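+ # class scope keeps the engine, and with it the single in-memory
+ # sqlite database, alive across the whole test class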
return engine
@classmethod
__backend__ = True
@classmethod
- def setup_class(cls):
+ def setup_test_class(cls):
cls.seq = Sequence("my_sequence")
cls.seq.create(testing.db)
@classmethod
- def teardown_class(cls):
+ def teardown_test_class(cls):
cls.seq.drop(testing.db)
def _assert_seq_result(self, ret):
class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
- def setup(self):
+ def setup_test(self):
class UTypeOne(types.UserDefinedType):
def get_col_spec(self):
return "UTYPEONE"
class JSONTest(fixtures.TestBase):
- def setup(self):
+ def setup_test(self):
metadata = MetaData()
self.test_table = Table(
"test_table",
@testing.requires.non_native_boolean_unconstrained
def test_constraint(self, connection):
assert_raises(
- (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError),
+ (
+ exc.IntegrityError,
+ exc.ProgrammingError,
+ exc.OperationalError,
+ exc.InternalError, # older pymysql versions raise this
+ ),
connection.exec_driver_sql,
"insert into boolean_table (id, value) values(1, 5)",
)
usedevelop=
cov: True
-deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only
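+# pytest 5.0 dropped Python 2 support, hence the split pins below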
+deps=
+ pytest>=4.6.11,<5.0; python_version < '3'
+ pytest>=6.2; python_version >= '3'
pytest-xdist
greenlet != 0.4.17
mock; python_version < '3.3'
sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
+ py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}}
py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000}
mysql: MYSQL={env:TOX_MYSQL:--db mysql}
+ py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}}
mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql}
# tox as of 2.0 blocks all environment variables from the
# outside, unless they are here (or in TOX_TESTENV_PASSENV,
# wildcards OK). Need at least these
-passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
+passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_POSTGRESQL_PY2K TOX_MYSQL TOX_MYSQL_PY2K TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
# for nocext, we rm *.so in lib in case we are doing usedevelop=True
commands=