]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve typing tests
authorFederico Caselli <cfederico87@gmail.com>
Fri, 23 Jun 2023 17:58:54 +0000 (19:58 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 27 Jun 2023 21:50:25 +0000 (23:50 +0200)
Extract a fixture to run mypy on files
Move the plain files to test/typing
Move test files from stubs repository
Transform the fixture module in a package

Change-Id: I23acaecb84e7c4b9010259d44395dc1df83a9385

106 files changed:
lib/sqlalchemy/orm/writeonly.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/fixtures.py [deleted file]
lib/sqlalchemy/testing/fixtures/__init__.py [new file with mode: 0644]
lib/sqlalchemy/testing/fixtures/base.py [new file with mode: 0644]
lib/sqlalchemy/testing/fixtures/mypy.py [new file with mode: 0644]
lib/sqlalchemy/testing/fixtures/orm.py [new file with mode: 0644]
lib/sqlalchemy/testing/fixtures/sql.py [new file with mode: 0644]
lib/sqlalchemy/testing/pickleable.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/testing/util.py
setup.cfg
test/aaa_profiling/test_memusage.py
test/ext/asyncio/test_session_py3k.py
test/ext/declarative/test_inheritance.py
test/ext/declarative/test_reflection.py
test/ext/mypy/plain_files/core_ddl.py [deleted file]
test/ext/mypy/test_mypy_plugin_py3k.py
test/ext/test_associationproxy.py
test/ext/test_indexable.py
test/ext/test_mutable.py
test/ext/test_serializer.py
test/orm/declarative/test_basic.py
test/orm/declarative/test_inheritance.py
test/orm/declarative/test_reflection.py
test/orm/dml/test_bulk_statements.py
test/orm/inheritance/_poly_fixtures.py
test/orm/inheritance/test_abc_polymorphic.py
test/orm/inheritance/test_assorted_poly.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_poly_loading.py
test/orm/inheritance/test_poly_persistence.py
test/orm/inheritance/test_relationship.py
test/orm/inheritance/test_selects.py
test/orm/inheritance/test_single.py
test/orm/test_ac_relationships.py
test/orm/test_attributes.py
test/orm/test_cascade.py
test/orm/test_defaults.py
test/orm/test_deferred.py
test/orm/test_deprecations.py
test/orm/test_eager_relations.py
test/orm/test_froms.py
test/orm/test_lazy_relations.py
test/orm/test_mapper.py
test/orm/test_merge.py
test/orm/test_query.py
test/orm/test_relationships.py
test/orm/test_scoping.py
test/orm/test_selectin_relations.py
test/orm/test_subquery_relations.py
test/orm/test_unitofwork.py
test/orm/test_unitofworkv2.py
test/orm/test_validators.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py [moved from test/ext/mypy/plain_files/pg_stuff.py with 100% similarity]
test/typing/plain_files/engine/engine_inspection.py [moved from test/ext/mypy/plain_files/engine_inspection.py with 100% similarity]
test/typing/plain_files/engine/engines.py [new file with mode: 0644]
test/typing/plain_files/ext/association_proxy/association_proxy_one.py [moved from test/ext/mypy/plain_files/association_proxy_one.py with 100% similarity]
test/typing/plain_files/ext/association_proxy/association_proxy_three.py [moved from test/ext/mypy/plain_files/association_proxy_three.py with 100% similarity]
test/typing/plain_files/ext/association_proxy/association_proxy_two.py [moved from test/ext/mypy/plain_files/association_proxy_two.py with 100% similarity]
test/typing/plain_files/ext/asyncio/async_sessionmaker.py [moved from test/ext/mypy/plain_files/async_sessionmaker.py with 94% similarity]
test/typing/plain_files/ext/asyncio/async_stuff.py [new file with mode: 0644]
test/typing/plain_files/ext/asyncio/create_proxy_methods.py [new file with mode: 0644]
test/typing/plain_files/ext/asyncio/engines.py [moved from test/ext/mypy/plain_files/engines.py with 73% similarity]
test/typing/plain_files/ext/hybrid/hybrid_four.py [moved from test/ext/mypy/plain_files/hybrid_four.py with 100% similarity]
test/typing/plain_files/ext/hybrid/hybrid_one.py [moved from test/ext/mypy/plain_files/hybrid_one.py with 100% similarity]
test/typing/plain_files/ext/hybrid/hybrid_three.py [moved from test/ext/mypy/plain_files/hybrid_three.py with 100% similarity]
test/typing/plain_files/ext/hybrid/hybrid_two.py [moved from test/ext/mypy/plain_files/hybrid_two.py with 100% similarity]
test/typing/plain_files/inspection_inspect.py [moved from test/ext/mypy/inspection_inspect.py with 78% similarity]
test/typing/plain_files/orm/complete_orm_no_plugin.py [moved from test/ext/mypy/plugin_files/complete_orm_no_plugin.py with 99% similarity]
test/typing/plain_files/orm/composite.py [moved from test/ext/mypy/plain_files/composite.py with 100% similarity]
test/typing/plain_files/orm/composite_dc.py [moved from test/ext/mypy/plain_files/composite_dc.py with 100% similarity]
test/typing/plain_files/orm/dataclass_transforms_one.py [moved from test/ext/mypy/plain_files/dataclass_transforms_one.py with 100% similarity]
test/typing/plain_files/orm/declared_attr_one.py [moved from test/ext/mypy/plain_files/declared_attr_one.py with 98% similarity]
test/typing/plain_files/orm/declared_attr_two.py [moved from test/ext/mypy/plain_files/declared_attr_two.py with 100% similarity]
test/typing/plain_files/orm/dynamic_rel.py [moved from test/ext/mypy/plain_files/dynamic_rel.py with 100% similarity]
test/typing/plain_files/orm/issue_9340.py [moved from test/ext/mypy/plain_files/issue_9340.py with 100% similarity]
test/typing/plain_files/orm/keyfunc_dict.py [moved from test/ext/mypy/plain_files/keyfunc_dict.py with 100% similarity]
test/typing/plain_files/orm/mapped_assign_expression.py [new file with mode: 0644]
test/typing/plain_files/orm/mapped_column.py [moved from test/ext/mypy/plain_files/mapped_column.py with 100% similarity]
test/typing/plain_files/orm/orm_config_constructs.py [moved from test/ext/mypy/plain_files/orm_config_constructs.py with 100% similarity]
test/typing/plain_files/orm/orm_querying.py [moved from test/ext/mypy/plain_files/orm_querying.py with 100% similarity]
test/typing/plain_files/orm/relationship.py [moved from test/ext/mypy/plain_files/experimental_relationship.py with 95% similarity]
test/typing/plain_files/orm/scoped_session.py [new file with mode: 0644]
test/typing/plain_files/orm/session.py [moved from test/ext/mypy/plain_files/session.py with 100% similarity]
test/typing/plain_files/orm/sessionmakers.py [moved from test/ext/mypy/plain_files/sessionmakers.py with 100% similarity]
test/typing/plain_files/orm/trad_relationship_uselist.py [moved from test/ext/mypy/plain_files/trad_relationship_uselist.py with 100% similarity]
test/typing/plain_files/orm/traditional_relationship.py [moved from test/ext/mypy/plain_files/traditional_relationship.py with 100% similarity]
test/typing/plain_files/orm/typed_queries.py [moved from test/ext/mypy/plain_files/typed_queries.py with 100% similarity]
test/typing/plain_files/orm/write_only.py [moved from test/ext/mypy/plain_files/write_only.py with 100% similarity]
test/typing/plain_files/sql/common_sql_element.py [moved from test/ext/mypy/plain_files/common_sql_element.py with 100% similarity]
test/typing/plain_files/sql/core_ddl.py [new file with mode: 0644]
test/typing/plain_files/sql/dml.py [moved from test/ext/mypy/plain_files/dml.py with 100% similarity]
test/typing/plain_files/sql/functions.py [moved from test/ext/mypy/plain_files/functions.py with 100% similarity]
test/typing/plain_files/sql/functions_again.py [new file with mode: 0644]
test/typing/plain_files/sql/lambda_stmt.py [moved from test/ext/mypy/plain_files/lambda_stmt.py with 100% similarity]
test/typing/plain_files/sql/lowercase_objects.py [new file with mode: 0644]
test/typing/plain_files/sql/operators.py [new file with mode: 0644]
test/typing/plain_files/sql/selectables.py [moved from test/ext/mypy/plain_files/selectables.py with 100% similarity]
test/typing/plain_files/sql/sql_operations.py [moved from test/ext/mypy/plain_files/sql_operations.py with 100% similarity]
test/typing/plain_files/sql/sqltypes.py [moved from test/ext/mypy/plain_files/sqltypes.py with 100% similarity]
test/typing/plain_files/sql/typed_results.py [moved from test/ext/mypy/plain_files/typed_results.py with 100% similarity]
test/typing/test_mypy.py [new file with mode: 0644]
test/typing/test_overloads.py [moved from test/ext/mypy/test_overloads.py with 100% similarity]
tools/generate_sql_functions.py

index 9f0dbeead2350da43ce5369beb5dd9973ed140bf..0f245835b095fa9816027e286e428a4599a1ded0 100644 (file)
@@ -255,7 +255,7 @@ class WriteOnlyAttributeImpl(
 
         state._modified_event(dict_, self, attributes.NEVER_SET)
 
-        # this is a hack to allow the fixtures.ComparableEntity fixture
+        # this is a hack to allow the entities.ComparableEntity fixture
         # to work
         dict_[self.key] = True
         return state.committed_state[self.key]
index d2bda4d83c5559c5c62cb22104ffa622d3c71d86..b8f03362f9aa4a7b8d62770aa46f199794a5bc8f 100644 (file)
@@ -9,6 +9,7 @@
 
 from __future__ import annotations
 
+from argparse import Namespace
 import collections
 import inspect
 import typing
@@ -34,6 +35,7 @@ test_schema_2 = None
 any_async = False
 _current = None
 ident = "main"
+options: Namespace = None  # type: ignore
 
 if typing.TYPE_CHECKING:
     from .plugin.plugin_base import FixtureFunctions
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
deleted file mode 100644 (file)
index cb08380..0000000
+++ /dev/null
@@ -1,1055 +0,0 @@
-# testing/fixtures.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
-
-from __future__ import annotations
-
-import itertools
-import random
-import re
-import sys
-from typing import Any
-
-import sqlalchemy as sa
-from . import assertions
-from . import config
-from . import mock
-from . import schema
-from .assertions import eq_
-from .assertions import ne_
-from .entities import BasicEntity
-from .entities import ComparableEntity
-from .entities import ComparableMixin  # noqa
-from .util import adict
-from .util import drop_all_tables_from_metadata
-from .. import Column
-from .. import event
-from .. import func
-from .. import Integer
-from .. import select
-from .. import Table
-from .. import util
-from ..orm import DeclarativeBase
-from ..orm import events as orm_events
-from ..orm import MappedAsDataclass
-from ..orm import registry
-from ..schema import sort_tables_and_constraints
-from ..sql import visitors
-from ..sql.elements import ClauseElement
-
-
-@config.mark_base_test_class()
-class TestBase:
-    # A sequence of requirement names matching testing.requires decorators
-    __requires__ = ()
-
-    # A sequence of dialect names to exclude from the test class.
-    __unsupported_on__ = ()
-
-    # If present, test class is only runnable for the *single* specified
-    # dialect.  If you need multiple, use __unsupported_on__ and invert.
-    __only_on__ = None
-
-    # A sequence of no-arg callables. If any are True, the entire testcase is
-    # skipped.
-    __skip_if__ = None
-
-    # if True, the testing reaper will not attempt to touch connection
-    # state after a test is completed and before the outer teardown
-    # starts
-    __leave_connections_for_teardown__ = False
-
-    def assert_(self, val, msg=None):
-        assert val, msg
-
-    @config.fixture()
-    def nocache(self):
-        _cache = config.db._compiled_cache
-        config.db._compiled_cache = None
-        yield
-        config.db._compiled_cache = _cache
-
-    @config.fixture()
-    def connection_no_trans(self):
-        eng = getattr(self, "bind", None) or config.db
-
-        with eng.connect() as conn:
-            yield conn
-
-    @config.fixture()
-    def connection(self):
-        global _connection_fixture_connection
-
-        eng = getattr(self, "bind", None) or config.db
-
-        conn = eng.connect()
-        trans = conn.begin()
-
-        _connection_fixture_connection = conn
-        yield conn
-
-        _connection_fixture_connection = None
-
-        if trans.is_active:
-            trans.rollback()
-        # trans would not be active here if the test is using
-        # the legacy @provide_metadata decorator still, as it will
-        # run a close all connections.
-        conn.close()
-
-    @config.fixture()
-    def close_result_when_finished(self):
-        to_close = []
-        to_consume = []
-
-        def go(result, consume=False):
-            to_close.append(result)
-            if consume:
-                to_consume.append(result)
-
-        yield go
-        for r in to_consume:
-            try:
-                r.all()
-            except:
-                pass
-        for r in to_close:
-            try:
-                r.close()
-            except:
-                pass
-
-    @config.fixture()
-    def registry(self, metadata):
-        reg = registry(
-            metadata=metadata,
-            type_annotation_map={
-                str: sa.String().with_variant(
-                    sa.String(50), "mysql", "mariadb", "oracle"
-                )
-            },
-        )
-        yield reg
-        reg.dispose()
-
-    @config.fixture
-    def decl_base(self, metadata):
-        _md = metadata
-
-        class Base(DeclarativeBase):
-            metadata = _md
-            type_annotation_map = {
-                str: sa.String().with_variant(
-                    sa.String(50), "mysql", "mariadb", "oracle"
-                )
-            }
-
-        yield Base
-        Base.registry.dispose()
-
-    @config.fixture
-    def dc_decl_base(self, metadata):
-        _md = metadata
-
-        class Base(MappedAsDataclass, DeclarativeBase):
-            metadata = _md
-            type_annotation_map = {
-                str: sa.String().with_variant(
-                    sa.String(50), "mysql", "mariadb"
-                )
-            }
-
-        yield Base
-        Base.registry.dispose()
-
-    @config.fixture()
-    def future_connection(self, future_engine, connection):
-        # integrate the future_engine and connection fixtures so
-        # that users of the "connection" fixture will get at the
-        # "future" connection
-        yield connection
-
-    @config.fixture()
-    def future_engine(self):
-        yield
-
-    @config.fixture()
-    def testing_engine(self):
-        from . import engines
-
-        def gen_testing_engine(
-            url=None,
-            options=None,
-            future=None,
-            asyncio=False,
-            transfer_staticpool=False,
-            share_pool=False,
-        ):
-            if options is None:
-                options = {}
-            options["scope"] = "fixture"
-            return engines.testing_engine(
-                url=url,
-                options=options,
-                asyncio=asyncio,
-                transfer_staticpool=transfer_staticpool,
-                share_pool=share_pool,
-            )
-
-        yield gen_testing_engine
-
-        engines.testing_reaper._drop_testing_engines("fixture")
-
-    @config.fixture()
-    def async_testing_engine(self, testing_engine):
-        def go(**kw):
-            kw["asyncio"] = True
-            return testing_engine(**kw)
-
-        return go
-
-    @config.fixture
-    def fixture_session(self):
-        return fixture_session()
-
-    @config.fixture()
-    def metadata(self, request):
-        """Provide bound MetaData for a single test, dropping afterwards."""
-
-        from ..sql import schema
-
-        metadata = schema.MetaData()
-        request.instance.metadata = metadata
-        yield metadata
-        del request.instance.metadata
-
-        if (
-            _connection_fixture_connection
-            and _connection_fixture_connection.in_transaction()
-        ):
-            trans = _connection_fixture_connection.get_transaction()
-            trans.rollback()
-            with _connection_fixture_connection.begin():
-                drop_all_tables_from_metadata(
-                    metadata, _connection_fixture_connection
-                )
-        else:
-            drop_all_tables_from_metadata(metadata, config.db)
-
-    @config.fixture(
-        params=[
-            (rollback, second_operation, begin_nested)
-            for rollback in (True, False)
-            for second_operation in ("none", "execute", "begin")
-            for begin_nested in (
-                True,
-                False,
-            )
-        ]
-    )
-    def trans_ctx_manager_fixture(self, request, metadata):
-        rollback, second_operation, begin_nested = request.param
-
-        t = Table("test", metadata, Column("data", Integer))
-        eng = getattr(self, "bind", None) or config.db
-
-        t.create(eng)
-
-        def run_test(subject, trans_on_subject, execute_on_subject):
-            with subject.begin() as trans:
-                if begin_nested:
-                    if not config.requirements.savepoints.enabled:
-                        config.skip_test("savepoints not enabled")
-                    if execute_on_subject:
-                        nested_trans = subject.begin_nested()
-                    else:
-                        nested_trans = trans.begin_nested()
-
-                    with nested_trans:
-                        if execute_on_subject:
-                            subject.execute(t.insert(), {"data": 10})
-                        else:
-                            trans.execute(t.insert(), {"data": 10})
-
-                        # for nested trans, we always commit/rollback on the
-                        # "nested trans" object itself.
-                        # only Session(future=False) will affect savepoint
-                        # transaction for session.commit/rollback
-
-                        if rollback:
-                            nested_trans.rollback()
-                        else:
-                            nested_trans.commit()
-
-                        if second_operation != "none":
-                            with assertions.expect_raises_message(
-                                sa.exc.InvalidRequestError,
-                                "Can't operate on closed transaction "
-                                "inside context "
-                                "manager.  Please complete the context "
-                                "manager "
-                                "before emitting further commands.",
-                            ):
-                                if second_operation == "execute":
-                                    if execute_on_subject:
-                                        subject.execute(
-                                            t.insert(), {"data": 12}
-                                        )
-                                    else:
-                                        trans.execute(t.insert(), {"data": 12})
-                                elif second_operation == "begin":
-                                    if execute_on_subject:
-                                        subject.begin_nested()
-                                    else:
-                                        trans.begin_nested()
-
-                    # outside the nested trans block, but still inside the
-                    # transaction block, we can run SQL, and it will be
-                    # committed
-                    if execute_on_subject:
-                        subject.execute(t.insert(), {"data": 14})
-                    else:
-                        trans.execute(t.insert(), {"data": 14})
-
-                else:
-                    if execute_on_subject:
-                        subject.execute(t.insert(), {"data": 10})
-                    else:
-                        trans.execute(t.insert(), {"data": 10})
-
-                    if trans_on_subject:
-                        if rollback:
-                            subject.rollback()
-                        else:
-                            subject.commit()
-                    else:
-                        if rollback:
-                            trans.rollback()
-                        else:
-                            trans.commit()
-
-                    if second_operation != "none":
-                        with assertions.expect_raises_message(
-                            sa.exc.InvalidRequestError,
-                            "Can't operate on closed transaction inside "
-                            "context "
-                            "manager.  Please complete the context manager "
-                            "before emitting further commands.",
-                        ):
-                            if second_operation == "execute":
-                                if execute_on_subject:
-                                    subject.execute(t.insert(), {"data": 12})
-                                else:
-                                    trans.execute(t.insert(), {"data": 12})
-                            elif second_operation == "begin":
-                                if hasattr(trans, "begin"):
-                                    trans.begin()
-                                else:
-                                    subject.begin()
-                            elif second_operation == "begin_nested":
-                                if execute_on_subject:
-                                    subject.begin_nested()
-                                else:
-                                    trans.begin_nested()
-
-            expected_committed = 0
-            if begin_nested:
-                # begin_nested variant, we inserted a row after the nested
-                # block
-                expected_committed += 1
-            if not rollback:
-                # not rollback variant, our row inserted in the target
-                # block itself would be committed
-                expected_committed += 1
-
-            if execute_on_subject:
-                eq_(
-                    subject.scalar(select(func.count()).select_from(t)),
-                    expected_committed,
-                )
-            else:
-                with subject.connect() as conn:
-                    eq_(
-                        conn.scalar(select(func.count()).select_from(t)),
-                        expected_committed,
-                    )
-
-        return run_test
-
-
-_connection_fixture_connection = None
-
-
-class FutureEngineMixin:
-    """alembic's suite still using this"""
-
-
-class TablesTest(TestBase):
-    # 'once', None
-    run_setup_bind = "once"
-
-    # 'once', 'each', None
-    run_define_tables = "once"
-
-    # 'once', 'each', None
-    run_create_tables = "once"
-
-    # 'once', 'each', None
-    run_inserts = "each"
-
-    # 'each', None
-    run_deletes = "each"
-
-    # 'once', None
-    run_dispose_bind = None
-
-    bind = None
-    _tables_metadata = None
-    tables = None
-    other = None
-    sequences = None
-
-    @config.fixture(autouse=True, scope="class")
-    def _setup_tables_test_class(self):
-        cls = self.__class__
-        cls._init_class()
-
-        cls._setup_once_tables()
-
-        cls._setup_once_inserts()
-
-        yield
-
-        cls._teardown_once_metadata_bind()
-
-    @config.fixture(autouse=True, scope="function")
-    def _setup_tables_test_instance(self):
-        self._setup_each_tables()
-        self._setup_each_inserts()
-
-        yield
-
-        self._teardown_each_tables()
-
-    @property
-    def tables_test_metadata(self):
-        return self._tables_metadata
-
-    @classmethod
-    def _init_class(cls):
-        if cls.run_define_tables == "each":
-            if cls.run_create_tables == "once":
-                cls.run_create_tables = "each"
-            assert cls.run_inserts in ("each", None)
-
-        cls.other = adict()
-        cls.tables = adict()
-        cls.sequences = adict()
-
-        cls.bind = cls.setup_bind()
-        cls._tables_metadata = sa.MetaData()
-
-    @classmethod
-    def _setup_once_inserts(cls):
-        if cls.run_inserts == "once":
-            cls._load_fixtures()
-            with cls.bind.begin() as conn:
-                cls.insert_data(conn)
-
-    @classmethod
-    def _setup_once_tables(cls):
-        if cls.run_define_tables == "once":
-            cls.define_tables(cls._tables_metadata)
-            if cls.run_create_tables == "once":
-                cls._tables_metadata.create_all(cls.bind)
-            cls.tables.update(cls._tables_metadata.tables)
-            cls.sequences.update(cls._tables_metadata._sequences)
-
-    def _setup_each_tables(self):
-        if self.run_define_tables == "each":
-            self.define_tables(self._tables_metadata)
-            if self.run_create_tables == "each":
-                self._tables_metadata.create_all(self.bind)
-            self.tables.update(self._tables_metadata.tables)
-            self.sequences.update(self._tables_metadata._sequences)
-        elif self.run_create_tables == "each":
-            self._tables_metadata.create_all(self.bind)
-
-    def _setup_each_inserts(self):
-        if self.run_inserts == "each":
-            self._load_fixtures()
-            with self.bind.begin() as conn:
-                self.insert_data(conn)
-
-    def _teardown_each_tables(self):
-        if self.run_define_tables == "each":
-            self.tables.clear()
-            if self.run_create_tables == "each":
-                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
-            self._tables_metadata.clear()
-        elif self.run_create_tables == "each":
-            drop_all_tables_from_metadata(self._tables_metadata, self.bind)
-
-        savepoints = getattr(config.requirements, "savepoints", False)
-        if savepoints:
-            savepoints = savepoints.enabled
-
-        # no need to run deletes if tables are recreated on setup
-        if (
-            self.run_define_tables != "each"
-            and self.run_create_tables != "each"
-            and self.run_deletes == "each"
-        ):
-            with self.bind.begin() as conn:
-                for table in reversed(
-                    [
-                        t
-                        for (t, fks) in sort_tables_and_constraints(
-                            self._tables_metadata.tables.values()
-                        )
-                        if t is not None
-                    ]
-                ):
-                    try:
-                        if savepoints:
-                            with conn.begin_nested():
-                                conn.execute(table.delete())
-                        else:
-                            conn.execute(table.delete())
-                    except sa.exc.DBAPIError as ex:
-                        print(
-                            ("Error emptying table %s: %r" % (table, ex)),
-                            file=sys.stderr,
-                        )
-
-    @classmethod
-    def _teardown_once_metadata_bind(cls):
-        if cls.run_create_tables:
-            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
-
-        if cls.run_dispose_bind == "once":
-            cls.dispose_bind(cls.bind)
-
-        cls._tables_metadata.bind = None
-
-        if cls.run_setup_bind is not None:
-            cls.bind = None
-
-    @classmethod
-    def setup_bind(cls):
-        return config.db
-
-    @classmethod
-    def dispose_bind(cls, bind):
-        if hasattr(bind, "dispose"):
-            bind.dispose()
-        elif hasattr(bind, "close"):
-            bind.close()
-
-    @classmethod
-    def define_tables(cls, metadata):
-        pass
-
-    @classmethod
-    def fixtures(cls):
-        return {}
-
-    @classmethod
-    def insert_data(cls, connection):
-        pass
-
-    def sql_count_(self, count, fn):
-        self.assert_sql_count(self.bind, fn, count)
-
-    def sql_eq_(self, callable_, statements):
-        self.assert_sql(self.bind, callable_, statements)
-
-    @classmethod
-    def _load_fixtures(cls):
-        """Insert rows as represented by the fixtures() method."""
-        headers, rows = {}, {}
-        for table, data in cls.fixtures().items():
-            if len(data) < 2:
-                continue
-            if isinstance(table, str):
-                table = cls.tables[table]
-            headers[table] = data[0]
-            rows[table] = data[1:]
-        for table, fks in sort_tables_and_constraints(
-            cls._tables_metadata.tables.values()
-        ):
-            if table is None:
-                continue
-            if table not in headers:
-                continue
-            with cls.bind.begin() as conn:
-                conn.execute(
-                    table.insert(),
-                    [
-                        dict(zip(headers[table], column_values))
-                        for column_values in rows[table]
-                    ],
-                )
-
-
-class NoCache:
-    @config.fixture(autouse=True, scope="function")
-    def _disable_cache(self):
-        _cache = config.db._compiled_cache
-        config.db._compiled_cache = None
-        yield
-        config.db._compiled_cache = _cache
-
-
-class RemovesEvents:
-    @util.memoized_property
-    def _event_fns(self):
-        return set()
-
-    def event_listen(self, target, name, fn, **kw):
-        self._event_fns.add((target, name, fn))
-        event.listen(target, name, fn, **kw)
-
-    @config.fixture(autouse=True, scope="function")
-    def _remove_events(self):
-        yield
-        for key in self._event_fns:
-            event.remove(*key)
-
-
-class RemoveORMEventsGlobally:
-    @config.fixture(autouse=True)
-    def _remove_listeners(self):
-        yield
-        orm_events.MapperEvents._clear()
-        orm_events.InstanceEvents._clear()
-        orm_events.SessionEvents._clear()
-        orm_events.InstrumentationEvents._clear()
-        orm_events.QueryEvents._clear()
-
-
-_fixture_sessions = set()
-
-
-def fixture_session(**kw):
-    kw.setdefault("autoflush", True)
-    kw.setdefault("expire_on_commit", True)
-
-    bind = kw.pop("bind", config.db)
-
-    sess = sa.orm.Session(bind, **kw)
-    _fixture_sessions.add(sess)
-    return sess
-
-
-def _close_all_sessions():
-    # will close all still-referenced sessions
-    sa.orm.session.close_all_sessions()
-    _fixture_sessions.clear()
-
-
-def stop_test_class_inside_fixtures(cls):
-    _close_all_sessions()
-    sa.orm.clear_mappers()
-
-
-def after_test():
-    if _fixture_sessions:
-        _close_all_sessions()
-
-
-class ORMTest(TestBase):
-    pass
-
-
-class MappedTest(TablesTest, assertions.AssertsExecutionResults):
-    # 'once', 'each', None
-    run_setup_classes = "once"
-
-    # 'once', 'each', None
-    run_setup_mappers = "each"
-
-    classes: Any = None
-
-    @config.fixture(autouse=True, scope="class")
-    def _setup_tables_test_class(self):
-        cls = self.__class__
-        cls._init_class()
-
-        if cls.classes is None:
-            cls.classes = adict()
-
-        cls._setup_once_tables()
-        cls._setup_once_classes()
-        cls._setup_once_mappers()
-        cls._setup_once_inserts()
-
-        yield
-
-        cls._teardown_once_class()
-        cls._teardown_once_metadata_bind()
-
-    @config.fixture(autouse=True, scope="function")
-    def _setup_tables_test_instance(self):
-        self._setup_each_tables()
-        self._setup_each_classes()
-        self._setup_each_mappers()
-        self._setup_each_inserts()
-
-        yield
-
-        sa.orm.session.close_all_sessions()
-        self._teardown_each_mappers()
-        self._teardown_each_classes()
-        self._teardown_each_tables()
-
-    @classmethod
-    def _teardown_once_class(cls):
-        cls.classes.clear()
-
-    @classmethod
-    def _setup_once_classes(cls):
-        if cls.run_setup_classes == "once":
-            cls._with_register_classes(cls.setup_classes)
-
-    @classmethod
-    def _setup_once_mappers(cls):
-        if cls.run_setup_mappers == "once":
-            cls.mapper_registry, cls.mapper = cls._generate_registry()
-            cls._with_register_classes(cls.setup_mappers)
-
-    def _setup_each_mappers(self):
-        if self.run_setup_mappers != "once":
-            (
-                self.__class__.mapper_registry,
-                self.__class__.mapper,
-            ) = self._generate_registry()
-
-        if self.run_setup_mappers == "each":
-            self._with_register_classes(self.setup_mappers)
-
-    def _setup_each_classes(self):
-        if self.run_setup_classes == "each":
-            self._with_register_classes(self.setup_classes)
-
-    @classmethod
-    def _generate_registry(cls):
-        decl = registry(metadata=cls._tables_metadata)
-        return decl, decl.map_imperatively
-
-    @classmethod
-    def _with_register_classes(cls, fn):
-        """Run a setup method, framing the operation with a Base class
-        that will catch new subclasses to be established within
-        the "classes" registry.
-
-        """
-        cls_registry = cls.classes
-
-        class _Base:
-            def __init_subclass__(cls) -> None:
-                assert cls_registry is not None
-                cls_registry[cls.__name__] = cls
-                super().__init_subclass__()
-
-        class Basic(BasicEntity, _Base):
-            pass
-
-        class Comparable(ComparableEntity, _Base):
-            pass
-
-        cls.Basic = Basic
-        cls.Comparable = Comparable
-        fn()
-
-    def _teardown_each_mappers(self):
-        # some tests create mappers in the test bodies
-        # and will define setup_mappers as None -
-        # clear mappers in any case
-        if self.run_setup_mappers != "once":
-            sa.orm.clear_mappers()
-
-    def _teardown_each_classes(self):
-        if self.run_setup_classes != "once":
-            self.classes.clear()
-
-    @classmethod
-    def setup_classes(cls):
-        pass
-
-    @classmethod
-    def setup_mappers(cls):
-        pass
-
-
-class DeclarativeMappedTest(MappedTest):
-    run_setup_classes = "once"
-    run_setup_mappers = "once"
-
-    @classmethod
-    def _setup_once_tables(cls):
-        pass
-
-    @classmethod
-    def _with_register_classes(cls, fn):
-        cls_registry = cls.classes
-
-        class _DeclBase(DeclarativeBase):
-            __table_cls__ = schema.Table
-            metadata = cls._tables_metadata
-            type_annotation_map = {
-                str: sa.String().with_variant(
-                    sa.String(50), "mysql", "mariadb", "oracle"
-                )
-            }
-
-            def __init_subclass__(cls, **kw) -> None:
-                assert cls_registry is not None
-                cls_registry[cls.__name__] = cls
-                super().__init_subclass__(**kw)
-
-        cls.DeclarativeBasic = _DeclBase
-
-        # sets up cls.Basic which is helpful for things like composite
-        # classes
-        super()._with_register_classes(fn)
-
-        if cls._tables_metadata.tables and cls.run_create_tables:
-            cls._tables_metadata.create_all(config.db)
-
-
-class ComputedReflectionFixtureTest(TablesTest):
-    run_inserts = run_deletes = None
-
-    __backend__ = True
-    __requires__ = ("computed_columns", "table_reflection")
-
-    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
-
-    def normalize(self, text):
-        return self.regexp.sub("", text).lower()
-
-    @classmethod
-    def define_tables(cls, metadata):
-        from .. import Integer
-        from .. import testing
-        from ..schema import Column
-        from ..schema import Computed
-        from ..schema import Table
-
-        Table(
-            "computed_default_table",
-            metadata,
-            Column("id", Integer, primary_key=True),
-            Column("normal", Integer),
-            Column("computed_col", Integer, Computed("normal + 42")),
-            Column("with_default", Integer, server_default="42"),
-        )
-
-        t = Table(
-            "computed_column_table",
-            metadata,
-            Column("id", Integer, primary_key=True),
-            Column("normal", Integer),
-            Column("computed_no_flag", Integer, Computed("normal + 42")),
-        )
-
-        if testing.requires.schemas.enabled:
-            t2 = Table(
-                "computed_column_table",
-                metadata,
-                Column("id", Integer, primary_key=True),
-                Column("normal", Integer),
-                Column("computed_no_flag", Integer, Computed("normal / 42")),
-                schema=config.test_schema,
-            )
-
-        if testing.requires.computed_columns_virtual.enabled:
-            t.append_column(
-                Column(
-                    "computed_virtual",
-                    Integer,
-                    Computed("normal + 2", persisted=False),
-                )
-            )
-            if testing.requires.schemas.enabled:
-                t2.append_column(
-                    Column(
-                        "computed_virtual",
-                        Integer,
-                        Computed("normal / 2", persisted=False),
-                    )
-                )
-        if testing.requires.computed_columns_stored.enabled:
-            t.append_column(
-                Column(
-                    "computed_stored",
-                    Integer,
-                    Computed("normal - 42", persisted=True),
-                )
-            )
-            if testing.requires.schemas.enabled:
-                t2.append_column(
-                    Column(
-                        "computed_stored",
-                        Integer,
-                        Computed("normal * 42", persisted=True),
-                    )
-                )
-
-
-class CacheKeyFixture:
-    def _compare_equal(self, a, b, compare_values):
-        a_key = a._generate_cache_key()
-        b_key = b._generate_cache_key()
-
-        if a_key is None:
-            assert a._annotations.get("nocache")
-
-            assert b_key is None
-        else:
-            eq_(a_key.key, b_key.key)
-            eq_(hash(a_key.key), hash(b_key.key))
-
-            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
-                assert a_param.compare(b_param, compare_values=compare_values)
-        return a_key, b_key
-
-    def _run_cache_key_fixture(self, fixture, compare_values):
-        case_a = fixture()
-        case_b = fixture()
-
-        for a, b in itertools.combinations_with_replacement(
-            range(len(case_a)), 2
-        ):
-            if a == b:
-                a_key, b_key = self._compare_equal(
-                    case_a[a], case_b[b], compare_values
-                )
-                if a_key is None:
-                    continue
-            else:
-                a_key = case_a[a]._generate_cache_key()
-                b_key = case_b[b]._generate_cache_key()
-
-                if a_key is None or b_key is None:
-                    if a_key is None:
-                        assert case_a[a]._annotations.get("nocache")
-                    if b_key is None:
-                        assert case_b[b]._annotations.get("nocache")
-                    continue
-
-                if a_key.key == b_key.key:
-                    for a_param, b_param in zip(
-                        a_key.bindparams, b_key.bindparams
-                    ):
-                        if not a_param.compare(
-                            b_param, compare_values=compare_values
-                        ):
-                            break
-                    else:
-                        # this fails unconditionally since we could not
-                        # find bound parameter values that differed.
-                        # Usually we intended to get two distinct keys here
-                        # so the failure will be more descriptive using the
-                        # ne_() assertion.
-                        ne_(a_key.key, b_key.key)
-                else:
-                    ne_(a_key.key, b_key.key)
-
-            # ClauseElement-specific test to ensure the cache key
-            # collected all the bound parameters that aren't marked
-            # as "literal execute"
-            if isinstance(case_a[a], ClauseElement) and isinstance(
-                case_b[b], ClauseElement
-            ):
-                assert_a_params = []
-                assert_b_params = []
-
-                for elem in visitors.iterate(case_a[a]):
-                    if elem.__visit_name__ == "bindparam":
-                        assert_a_params.append(elem)
-
-                for elem in visitors.iterate(case_b[b]):
-                    if elem.__visit_name__ == "bindparam":
-                        assert_b_params.append(elem)
-
-                # note we're asserting the order of the params as well as
-                # if there are dupes or not.  ordering has to be
-                # deterministic and matches what a traversal would provide.
-                eq_(
-                    sorted(a_key.bindparams, key=lambda b: b.key),
-                    sorted(
-                        util.unique_list(assert_a_params), key=lambda b: b.key
-                    ),
-                )
-                eq_(
-                    sorted(b_key.bindparams, key=lambda b: b.key),
-                    sorted(
-                        util.unique_list(assert_b_params), key=lambda b: b.key
-                    ),
-                )
-
-    def _run_cache_key_equal_fixture(self, fixture, compare_values):
-        case_a = fixture()
-        case_b = fixture()
-
-        for a, b in itertools.combinations_with_replacement(
-            range(len(case_a)), 2
-        ):
-            self._compare_equal(case_a[a], case_b[b], compare_values)
-
-
-def insertmanyvalues_fixture(
-    connection, randomize_rows=False, warn_on_downgraded=False
-):
-    dialect = connection.dialect
-    orig_dialect = dialect._deliver_insertmanyvalues_batches
-    orig_conn = connection._exec_insertmany_context
-
-    class RandomCursor:
-        __slots__ = ("cursor",)
-
-        def __init__(self, cursor):
-            self.cursor = cursor
-
-        # only this method is called by the deliver method.
-        # by not having the other methods we assert that those aren't being
-        # used
-
-        def fetchall(self):
-            rows = self.cursor.fetchall()
-            rows = list(rows)
-            random.shuffle(rows)
-            return rows
-
-    def _deliver_insertmanyvalues_batches(
-        cursor, statement, parameters, generic_setinputsizes, context
-    ):
-        if randomize_rows:
-            cursor = RandomCursor(cursor)
-        for batch in orig_dialect(
-            cursor, statement, parameters, generic_setinputsizes, context
-        ):
-            if warn_on_downgraded and batch.is_downgraded:
-                util.warn("Batches were downgraded for sorted INSERT")
-
-            yield batch
-
-    def _exec_insertmany_context(
-        dialect,
-        context,
-    ):
-        with mock.patch.object(
-            dialect,
-            "_deliver_insertmanyvalues_batches",
-            new=_deliver_insertmanyvalues_batches,
-        ):
-            return orig_conn(dialect, context)
-
-    connection._exec_insertmany_context = _exec_insertmany_context
diff --git a/lib/sqlalchemy/testing/fixtures/__init__.py b/lib/sqlalchemy/testing/fixtures/__init__.py
new file mode 100644 (file)
index 0000000..932051c
--- /dev/null
@@ -0,0 +1,28 @@
+# testing/fixtures/__init__.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from .base import FutureEngineMixin as FutureEngineMixin
+from .base import TestBase as TestBase
+from .mypy import MypyTest as MypyTest
+from .orm import after_test as after_test
+from .orm import close_all_sessions as close_all_sessions
+from .orm import DeclarativeMappedTest as DeclarativeMappedTest
+from .orm import fixture_session as fixture_session
+from .orm import MappedTest as MappedTest
+from .orm import ORMTest as ORMTest
+from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally
+from .orm import (
+    stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
+)
+from .sql import CacheKeyFixture as CacheKeyFixture
+from .sql import (
+    ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
+)
+from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture
+from .sql import NoCache as NoCache
+from .sql import RemovesEvents as RemovesEvents
+from .sql import TablesTest as TablesTest
diff --git a/lib/sqlalchemy/testing/fixtures/base.py b/lib/sqlalchemy/testing/fixtures/base.py
new file mode 100644 (file)
index 0000000..199ae71
--- /dev/null
@@ -0,0 +1,366 @@
+# testing/fixtures/base.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from .. import assertions
+from .. import config
+from ..assertions import eq_
+from ..util import drop_all_tables_from_metadata
+from ... import Column
+from ... import func
+from ... import Integer
+from ... import select
+from ... import Table
+from ...orm import DeclarativeBase
+from ...orm import MappedAsDataclass
+from ...orm import registry
+
+
+@config.mark_base_test_class()
+class TestBase:
+    # A sequence of requirement names matching testing.requires decorators
+    __requires__ = ()
+
+    # A sequence of dialect names to exclude from the test class.
+    __unsupported_on__ = ()
+
+    # If present, test class is only runnable for the *single* specified
+    # dialect.  If you need multiple, use __unsupported_on__ and invert.
+    __only_on__ = None
+
+    # A sequence of no-arg callables. If any are True, the entire testcase is
+    # skipped.
+    __skip_if__ = None
+
+    # if True, the testing reaper will not attempt to touch connection
+    # state after a test is completed and before the outer teardown
+    # starts
+    __leave_connections_for_teardown__ = False
+
+    def assert_(self, val, msg=None):
+        assert val, msg
+
+    @config.fixture()
+    def nocache(self):
+        _cache = config.db._compiled_cache
+        config.db._compiled_cache = None
+        yield
+        config.db._compiled_cache = _cache
+
+    @config.fixture()
+    def connection_no_trans(self):
+        eng = getattr(self, "bind", None) or config.db
+
+        with eng.connect() as conn:
+            yield conn
+
+    @config.fixture()
+    def connection(self):
+        global _connection_fixture_connection
+
+        eng = getattr(self, "bind", None) or config.db
+
+        conn = eng.connect()
+        trans = conn.begin()
+
+        _connection_fixture_connection = conn
+        yield conn
+
+        _connection_fixture_connection = None
+
+        if trans.is_active:
+            trans.rollback()
+        # trans would not be active here if the test is using
+        # the legacy @provide_metadata decorator still, as it will
+        # run a close all connections.
+        conn.close()
+
+    @config.fixture()
+    def close_result_when_finished(self):
+        to_close = []
+        to_consume = []
+
+        def go(result, consume=False):
+            to_close.append(result)
+            if consume:
+                to_consume.append(result)
+
+        yield go
+        for r in to_consume:
+            try:
+                r.all()
+            except:
+                pass
+        for r in to_close:
+            try:
+                r.close()
+            except:
+                pass
+
+    @config.fixture()
+    def registry(self, metadata):
+        reg = registry(
+            metadata=metadata,
+            type_annotation_map={
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb", "oracle"
+                )
+            },
+        )
+        yield reg
+        reg.dispose()
+
+    @config.fixture
+    def decl_base(self, metadata):
+        _md = metadata
+
+        class Base(DeclarativeBase):
+            metadata = _md
+            type_annotation_map = {
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb", "oracle"
+                )
+            }
+
+        yield Base
+        Base.registry.dispose()
+
+    @config.fixture
+    def dc_decl_base(self, metadata):
+        _md = metadata
+
+        class Base(MappedAsDataclass, DeclarativeBase):
+            metadata = _md
+            type_annotation_map = {
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb"
+                )
+            }
+
+        yield Base
+        Base.registry.dispose()
+
+    @config.fixture()
+    def future_connection(self, future_engine, connection):
+        # integrate the future_engine and connection fixtures so
+        # that users of the "connection" fixture will get at the
+        # "future" connection
+        yield connection
+
+    @config.fixture()
+    def future_engine(self):
+        yield
+
+    @config.fixture()
+    def testing_engine(self):
+        from .. import engines
+
+        def gen_testing_engine(
+            url=None,
+            options=None,
+            future=None,
+            asyncio=False,
+            transfer_staticpool=False,
+            share_pool=False,
+        ):
+            if options is None:
+                options = {}
+            options["scope"] = "fixture"
+            return engines.testing_engine(
+                url=url,
+                options=options,
+                asyncio=asyncio,
+                transfer_staticpool=transfer_staticpool,
+                share_pool=share_pool,
+            )
+
+        yield gen_testing_engine
+
+        engines.testing_reaper._drop_testing_engines("fixture")
+
+    @config.fixture()
+    def async_testing_engine(self, testing_engine):
+        def go(**kw):
+            kw["asyncio"] = True
+            return testing_engine(**kw)
+
+        return go
+
+    @config.fixture()
+    def metadata(self, request):
+        """Provide bound MetaData for a single test, dropping afterwards."""
+
+        from ...sql import schema
+
+        metadata = schema.MetaData()
+        request.instance.metadata = metadata
+        yield metadata
+        del request.instance.metadata
+
+        if (
+            _connection_fixture_connection
+            and _connection_fixture_connection.in_transaction()
+        ):
+            trans = _connection_fixture_connection.get_transaction()
+            trans.rollback()
+            with _connection_fixture_connection.begin():
+                drop_all_tables_from_metadata(
+                    metadata, _connection_fixture_connection
+                )
+        else:
+            drop_all_tables_from_metadata(metadata, config.db)
+
+    @config.fixture(
+        params=[
+            (rollback, second_operation, begin_nested)
+            for rollback in (True, False)
+            for second_operation in ("none", "execute", "begin")
+            for begin_nested in (
+                True,
+                False,
+            )
+        ]
+    )
+    def trans_ctx_manager_fixture(self, request, metadata):
+        rollback, second_operation, begin_nested = request.param
+
+        t = Table("test", metadata, Column("data", Integer))
+        eng = getattr(self, "bind", None) or config.db
+
+        t.create(eng)
+
+        def run_test(subject, trans_on_subject, execute_on_subject):
+            with subject.begin() as trans:
+                if begin_nested:
+                    if not config.requirements.savepoints.enabled:
+                        config.skip_test("savepoints not enabled")
+                    if execute_on_subject:
+                        nested_trans = subject.begin_nested()
+                    else:
+                        nested_trans = trans.begin_nested()
+
+                    with nested_trans:
+                        if execute_on_subject:
+                            subject.execute(t.insert(), {"data": 10})
+                        else:
+                            trans.execute(t.insert(), {"data": 10})
+
+                        # for nested trans, we always commit/rollback on the
+                        # "nested trans" object itself.
+                        # only Session(future=False) will affect savepoint
+                        # transaction for session.commit/rollback
+
+                        if rollback:
+                            nested_trans.rollback()
+                        else:
+                            nested_trans.commit()
+
+                        if second_operation != "none":
+                            with assertions.expect_raises_message(
+                                sa.exc.InvalidRequestError,
+                                "Can't operate on closed transaction "
+                                "inside context "
+                                "manager.  Please complete the context "
+                                "manager "
+                                "before emitting further commands.",
+                            ):
+                                if second_operation == "execute":
+                                    if execute_on_subject:
+                                        subject.execute(
+                                            t.insert(), {"data": 12}
+                                        )
+                                    else:
+                                        trans.execute(t.insert(), {"data": 12})
+                                elif second_operation == "begin":
+                                    if execute_on_subject:
+                                        subject.begin_nested()
+                                    else:
+                                        trans.begin_nested()
+
+                    # outside the nested trans block, but still inside the
+                    # transaction block, we can run SQL, and it will be
+                    # committed
+                    if execute_on_subject:
+                        subject.execute(t.insert(), {"data": 14})
+                    else:
+                        trans.execute(t.insert(), {"data": 14})
+
+                else:
+                    if execute_on_subject:
+                        subject.execute(t.insert(), {"data": 10})
+                    else:
+                        trans.execute(t.insert(), {"data": 10})
+
+                    if trans_on_subject:
+                        if rollback:
+                            subject.rollback()
+                        else:
+                            subject.commit()
+                    else:
+                        if rollback:
+                            trans.rollback()
+                        else:
+                            trans.commit()
+
+                    if second_operation != "none":
+                        with assertions.expect_raises_message(
+                            sa.exc.InvalidRequestError,
+                            "Can't operate on closed transaction inside "
+                            "context "
+                            "manager.  Please complete the context manager "
+                            "before emitting further commands.",
+                        ):
+                            if second_operation == "execute":
+                                if execute_on_subject:
+                                    subject.execute(t.insert(), {"data": 12})
+                                else:
+                                    trans.execute(t.insert(), {"data": 12})
+                            elif second_operation == "begin":
+                                if hasattr(trans, "begin"):
+                                    trans.begin()
+                                else:
+                                    subject.begin()
+                            elif second_operation == "begin_nested":
+                                if execute_on_subject:
+                                    subject.begin_nested()
+                                else:
+                                    trans.begin_nested()
+
+            expected_committed = 0
+            if begin_nested:
+                # begin_nested variant, we inserted a row after the nested
+                # block
+                expected_committed += 1
+            if not rollback:
+                # not rollback variant, our row inserted in the target
+                # block itself would be committed
+                expected_committed += 1
+
+            if execute_on_subject:
+                eq_(
+                    subject.scalar(select(func.count()).select_from(t)),
+                    expected_committed,
+                )
+            else:
+                with subject.connect() as conn:
+                    eq_(
+                        conn.scalar(select(func.count()).select_from(t)),
+                        expected_committed,
+                    )
+
+        return run_test
+
+
+_connection_fixture_connection = None
+
+
+class FutureEngineMixin:
+    """alembic's suite still using this"""
diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py
new file mode 100644 (file)
index 0000000..80e5ee0
--- /dev/null
@@ -0,0 +1,308 @@
+# testing/fixtures/mypy.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+import inspect
+import os
+from pathlib import Path
+import re
+import shutil
+import sys
+import tempfile
+
+from .base import TestBase
+from .. import config
+from ..assertions import eq_
+from ... import util
+
+
+@config.add_to_marker.mypy
+class MypyTest(TestBase):
+    __requires__ = ("no_sqlalchemy2_stubs",)
+
+    @config.fixture(scope="function")
+    def per_func_cachedir(self):
+        yield from self._cachedir()
+
+    @config.fixture(scope="class")
+    def cachedir(self):
+        yield from self._cachedir()
+
+    def _cachedir(self):
+        # as of mypy 0.971 i think we need to keep mypy_path empty
+        mypy_path = ""
+
+        with tempfile.TemporaryDirectory() as cachedir:
+            with open(
+                Path(cachedir) / "sqla_mypy_config.cfg", "w"
+            ) as config_file:
+                config_file.write(
+                    f"""
+                    [mypy]\n
+                    plugins = sqlalchemy.ext.mypy.plugin\n
+                    show_error_codes = True\n
+                    {mypy_path}
+                    disable_error_code = no-untyped-call
+
+                    [mypy-sqlalchemy.*]
+                    ignore_errors = True
+
+                    """
+                )
+            with open(
+                Path(cachedir) / "plain_mypy_config.cfg", "w"
+            ) as config_file:
+                config_file.write(
+                    f"""
+                    [mypy]\n
+                    show_error_codes = True\n
+                    {mypy_path}
+                    disable_error_code = var-annotated,no-untyped-call
+                    [mypy-sqlalchemy.*]
+                    ignore_errors = True
+
+                    """
+                )
+            yield cachedir
+
+    @config.fixture()
+    def mypy_runner(self, cachedir):
+        from mypy import api
+
+        def run(path, use_plugin=False, use_cachedir=None):
+            if use_cachedir is None:
+                use_cachedir = cachedir
+            args = [
+                "--strict",
+                "--raise-exceptions",
+                "--cache-dir",
+                use_cachedir,
+                "--config-file",
+                os.path.join(
+                    use_cachedir,
+                    "sqla_mypy_config.cfg"
+                    if use_plugin
+                    else "plain_mypy_config.cfg",
+                ),
+            ]
+
+            # mypy as of 0.990 is more aggressively blocking messaging
+            # for paths that are in sys.path, and as pytest puts currdir,
+            # test/ etc in sys.path, just copy the source file to the
+            # tempdir we are working in so that we don't have to try to
+            # manipulate sys.path and/or guess what mypy is doing
+            filename = os.path.basename(path)
+            test_program = os.path.join(use_cachedir, filename)
+            if path != test_program:
+                shutil.copyfile(path, test_program)
+            args.append(test_program)
+
+            # I set this locally but for the suite here needs to be
+            # disabled
+            os.environ.pop("MYPY_FORCE_COLOR", None)
+
+            stdout, stderr, exitcode = api.run(args)
+            return stdout, stderr, exitcode
+
+        return run
+
+    @config.fixture
+    def mypy_typecheck_file(self, mypy_runner):
+        def run(path, use_plugin=False):
+            expected_messages = self._collect_messages(path)
+            stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
+            self._check_output(
+                path, expected_messages, stdout, stderr, exitcode
+            )
+
+        return run
+
+    @staticmethod
+    def file_combinations(dirname):
+        if os.path.isabs(dirname):
+            path = dirname
+        else:
+            caller_path = inspect.stack()[1].filename
+            path = os.path.join(os.path.dirname(caller_path), dirname)
+        files = list(Path(path).glob("**/*.py"))
+
+        for extra_dir in config.options.mypy_extra_test_paths:
+            if extra_dir and os.path.isdir(extra_dir):
+                files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
+        return files
+
+    def _collect_messages(self, path):
+        from sqlalchemy.ext.mypy.util import mypy_14
+
+        expected_messages = []
+        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
+        py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
+        with open(path) as file_:
+            current_assert_messages = []
+            for num, line in enumerate(file_, 1):
+                m = py_ver_re.match(line)
+                if m:
+                    major, _, minor = m.group(1).partition(".")
+                    if sys.version_info < (int(major), int(minor)):
+                        config.skip_test(
+                            "Requires python >= %s" % (m.group(1))
+                        )
+                    continue
+
+                m = expected_re.match(line)
+                if m:
+                    is_mypy = bool(m.group(1))
+                    is_re = bool(m.group(2))
+                    is_type = bool(m.group(3))
+
+                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
+                    if is_type:
+                        if not is_re:
+                            # the goal here is that we can cut-and-paste
+                            # from vscode -> pylance into the
+                            # EXPECTED_TYPE: line, then the test suite will
+                            # validate that line against what mypy produces
+                            expected_msg = re.sub(
+                                r"([\[\]])",
+                                lambda m: rf"\{m.group(0)}",
+                                expected_msg,
+                            )
+
+                            # note making sure preceding text matches
+                            # with a dot, so that an expect for "Select"
+                            # does not match "TypedSelect"
+                            expected_msg = re.sub(
+                                r"([\w_]+)",
+                                lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
+                                expected_msg,
+                            )
+
+                            expected_msg = re.sub(
+                                "List", "builtins.list", expected_msg
+                            )
+
+                            expected_msg = re.sub(
+                                r"\b(int|str|float|bool)\b",
+                                lambda m: rf"builtins.{m.group(0)}\*?",
+                                expected_msg,
+                            )
+                            # expected_msg = re.sub(
+                            #     r"(Sequence|Tuple|List|Union)",
+                            #     lambda m: fr"typing.{m.group(0)}\*?",
+                            #     expected_msg,
+                            # )
+
+                        is_mypy = is_re = True
+                        expected_msg = f'Revealed type is "{expected_msg}"'
+
+                    if mypy_14 and util.py39:
+                        # use_lowercase_names, py39 and above
+                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363  # noqa: E501
+
+                        # skip first character which could be capitalized
+                        # "List item x not found" type of message
+                        expected_msg = expected_msg[0] + re.sub(
+                            r"\b(List|Tuple|Dict|Set)\b"
+                            if is_type
+                            else r"\b(List|Tuple|Dict|Set|Type)\b",
+                            lambda m: m.group(1).lower(),
+                            expected_msg[1:],
+                        )
+
+                    if mypy_14 and util.py310:
+                        # use_or_syntax, py310 and above
+                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368  # noqa: E501
+                        expected_msg = re.sub(
+                            r"Optional\[(.*?)\]",
+                            lambda m: f"{m.group(1)} | None",
+                            expected_msg,
+                        )
+                    current_assert_messages.append(
+                        (is_mypy, is_re, expected_msg.strip())
+                    )
+                elif current_assert_messages:
+                    expected_messages.extend(
+                        (num, is_mypy, is_re, expected_msg)
+                        for (
+                            is_mypy,
+                            is_re,
+                            expected_msg,
+                        ) in current_assert_messages
+                    )
+                    current_assert_messages[:] = []
+
+        return expected_messages
+
+    def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
+        not_located = []
+        filename = os.path.basename(path)
+        if expected_messages:
+            # mypy 0.990 changed how return codes work, so don't assume a
+            # 1 or a 0 return code here, could be either depending on if
+            # errors were generated or not
+
+            output = []
+
+            raw_lines = stdout.split("\n")
+            while raw_lines:
+                e = raw_lines.pop(0)
+                if re.match(r".+\.py:\d+: error: .*", e):
+                    output.append(("error", e))
+                elif re.match(
+                    r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
+                ):
+                    while raw_lines:
+                        ol = raw_lines.pop(0)
+                        if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
+                            break
+                elif re.match(
+                    r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
+                ):
+                    pass
+                elif re.match(r".+\.py:\d+: note: .*", e):
+                    output.append(("note", e))
+
+            for num, is_mypy, is_re, msg in expected_messages:
+                msg = msg.replace("'", '"')
+                prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
+                for idx, (typ, errmsg) in enumerate(output):
+                    if is_re:
+                        if re.match(
+                            rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
+                            errmsg,
+                        ):
+                            break
+                    elif (
+                        f"{filename}:{num}: {typ}: {prefix}{msg}"
+                        in errmsg.replace("'", '"')
+                    ):
+                        break
+                else:
+                    not_located.append(msg)
+                    continue
+                del output[idx]
+
+            if not_located:
+                missing = "\n".join(not_located)
+                print("Couldn't locate expected messages:", missing, sep="\n")
+                if output:
+                    extra = "\n".join(msg for _, msg in output)
+                    print("Remaining messages:", extra, sep="\n")
+                assert False, "expected messages not found, see stdout"
+
+            if output:
+                print(f"{len(output)} messages from mypy were not consumed:")
+                print("\n".join(msg for _, msg in output))
+                assert False, "errors and/or notes remain, see stdout"
+
+        else:
+            if exitcode != 0:
+                print(stdout, stderr, sep="\n")
+
+            eq_(exitcode, 0, msg=stdout)
diff --git a/lib/sqlalchemy/testing/fixtures/orm.py b/lib/sqlalchemy/testing/fixtures/orm.py
new file mode 100644 (file)
index 0000000..da622c0
--- /dev/null
@@ -0,0 +1,227 @@
+# testing/fixtures/orm.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+
+import sqlalchemy as sa
+from .base import TestBase
+from .sql import TablesTest
+from .. import assertions
+from .. import config
+from .. import schema
+from ..entities import BasicEntity
+from ..entities import ComparableEntity
+from ..util import adict
+from ... import orm
+from ...orm import DeclarativeBase
+from ...orm import events as orm_events
+from ...orm import registry
+
+
+class ORMTest(TestBase):
+    @config.fixture
+    def fixture_session(self):
+        return fixture_session()
+
+
+class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults):
+    # 'once', 'each', None
+    run_setup_classes = "once"
+
+    # 'once', 'each', None
+    run_setup_mappers = "each"
+
+    classes: Any = None
+
+    @config.fixture(autouse=True, scope="class")
+    def _setup_tables_test_class(self):
+        cls = self.__class__
+        cls._init_class()
+
+        if cls.classes is None:
+            cls.classes = adict()
+
+        cls._setup_once_tables()
+        cls._setup_once_classes()
+        cls._setup_once_mappers()
+        cls._setup_once_inserts()
+
+        yield
+
+        cls._teardown_once_class()
+        cls._teardown_once_metadata_bind()
+
+    @config.fixture(autouse=True, scope="function")
+    def _setup_tables_test_instance(self):
+        self._setup_each_tables()
+        self._setup_each_classes()
+        self._setup_each_mappers()
+        self._setup_each_inserts()
+
+        yield
+
+        orm.session.close_all_sessions()
+        self._teardown_each_mappers()
+        self._teardown_each_classes()
+        self._teardown_each_tables()
+
+    @classmethod
+    def _teardown_once_class(cls):
+        cls.classes.clear()
+
+    @classmethod
+    def _setup_once_classes(cls):
+        if cls.run_setup_classes == "once":
+            cls._with_register_classes(cls.setup_classes)
+
+    @classmethod
+    def _setup_once_mappers(cls):
+        if cls.run_setup_mappers == "once":
+            cls.mapper_registry, cls.mapper = cls._generate_registry()
+            cls._with_register_classes(cls.setup_mappers)
+
+    def _setup_each_mappers(self):
+        if self.run_setup_mappers != "once":
+            (
+                self.__class__.mapper_registry,
+                self.__class__.mapper,
+            ) = self._generate_registry()
+
+        if self.run_setup_mappers == "each":
+            self._with_register_classes(self.setup_mappers)
+
+    def _setup_each_classes(self):
+        if self.run_setup_classes == "each":
+            self._with_register_classes(self.setup_classes)
+
+    @classmethod
+    def _generate_registry(cls):
+        decl = registry(metadata=cls._tables_metadata)
+        return decl, decl.map_imperatively
+
+    @classmethod
+    def _with_register_classes(cls, fn):
+        """Run a setup method, framing the operation with a Base class
+        that will catch new subclasses to be established within
+        the "classes" registry.
+
+        """
+        cls_registry = cls.classes
+
+        class _Base:
+            def __init_subclass__(cls) -> None:
+                assert cls_registry is not None
+                cls_registry[cls.__name__] = cls
+                super().__init_subclass__()
+
+        class Basic(BasicEntity, _Base):
+            pass
+
+        class Comparable(ComparableEntity, _Base):
+            pass
+
+        cls.Basic = Basic
+        cls.Comparable = Comparable
+        fn()
+
+    def _teardown_each_mappers(self):
+        # some tests create mappers in the test bodies
+        # and will define setup_mappers as None -
+        # clear mappers in any case
+        if self.run_setup_mappers != "once":
+            orm.clear_mappers()
+
+    def _teardown_each_classes(self):
+        if self.run_setup_classes != "once":
+            self.classes.clear()
+
+    @classmethod
+    def setup_classes(cls):
+        pass
+
+    @classmethod
+    def setup_mappers(cls):
+        pass
+
+
+class DeclarativeMappedTest(MappedTest):
+    run_setup_classes = "once"
+    run_setup_mappers = "once"
+
+    @classmethod
+    def _setup_once_tables(cls):
+        pass
+
+    @classmethod
+    def _with_register_classes(cls, fn):
+        cls_registry = cls.classes
+
+        class _DeclBase(DeclarativeBase):
+            __table_cls__ = schema.Table
+            metadata = cls._tables_metadata
+            type_annotation_map = {
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb", "oracle"
+                )
+            }
+
+            def __init_subclass__(cls, **kw) -> None:
+                assert cls_registry is not None
+                cls_registry[cls.__name__] = cls
+                super().__init_subclass__(**kw)
+
+        cls.DeclarativeBasic = _DeclBase
+
+        # sets up cls.Basic which is helpful for things like composite
+        # classes
+        super()._with_register_classes(fn)
+
+        if cls._tables_metadata.tables and cls.run_create_tables:
+            cls._tables_metadata.create_all(config.db)
+
+
+class RemoveORMEventsGlobally:
+    @config.fixture(autouse=True)
+    def _remove_listeners(self):
+        yield
+        orm_events.MapperEvents._clear()
+        orm_events.InstanceEvents._clear()
+        orm_events.SessionEvents._clear()
+        orm_events.InstrumentationEvents._clear()
+        orm_events.QueryEvents._clear()
+
+
+_fixture_sessions = set()
+
+
+def fixture_session(**kw):
+    kw.setdefault("autoflush", True)
+    kw.setdefault("expire_on_commit", True)
+
+    bind = kw.pop("bind", config.db)
+
+    sess = orm.Session(bind, **kw)
+    _fixture_sessions.add(sess)
+    return sess
+
+
+def close_all_sessions():
+    # will close all still-referenced sessions
+    orm.close_all_sessions()
+    _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+    close_all_sessions()
+    orm.clear_mappers()
+
+
+def after_test():
+    if _fixture_sessions:
+        close_all_sessions()
diff --git a/lib/sqlalchemy/testing/fixtures/sql.py b/lib/sqlalchemy/testing/fixtures/sql.py
new file mode 100644 (file)
index 0000000..911dddd
--- /dev/null
@@ -0,0 +1,492 @@
+# testing/fixtures/sql.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from __future__ import annotations
+
+import itertools
+import random
+import re
+import sys
+
+import sqlalchemy as sa
+from .base import TestBase
+from .. import config
+from .. import mock
+from ..assertions import eq_
+from ..assertions import ne_
+from ..util import adict
+from ..util import drop_all_tables_from_metadata
+from ... import event
+from ... import util
+from ...schema import sort_tables_and_constraints
+from ...sql import visitors
+from ...sql.elements import ClauseElement
+
+
+class TablesTest(TestBase):
+    # 'once', None
+    run_setup_bind = "once"
+
+    # 'once', 'each', None
+    run_define_tables = "once"
+
+    # 'once', 'each', None
+    run_create_tables = "once"
+
+    # 'once', 'each', None
+    run_inserts = "each"
+
+    # 'each', None
+    run_deletes = "each"
+
+    # 'once', None
+    run_dispose_bind = None
+
+    bind = None
+    _tables_metadata = None
+    tables = None
+    other = None
+    sequences = None
+
+    @config.fixture(autouse=True, scope="class")
+    def _setup_tables_test_class(self):
+        cls = self.__class__
+        cls._init_class()
+
+        cls._setup_once_tables()
+
+        cls._setup_once_inserts()
+
+        yield
+
+        cls._teardown_once_metadata_bind()
+
+    @config.fixture(autouse=True, scope="function")
+    def _setup_tables_test_instance(self):
+        self._setup_each_tables()
+        self._setup_each_inserts()
+
+        yield
+
+        self._teardown_each_tables()
+
+    @property
+    def tables_test_metadata(self):
+        return self._tables_metadata
+
+    @classmethod
+    def _init_class(cls):
+        if cls.run_define_tables == "each":
+            if cls.run_create_tables == "once":
+                cls.run_create_tables = "each"
+            assert cls.run_inserts in ("each", None)
+
+        cls.other = adict()
+        cls.tables = adict()
+        cls.sequences = adict()
+
+        cls.bind = cls.setup_bind()
+        cls._tables_metadata = sa.MetaData()
+
+    @classmethod
+    def _setup_once_inserts(cls):
+        if cls.run_inserts == "once":
+            cls._load_fixtures()
+            with cls.bind.begin() as conn:
+                cls.insert_data(conn)
+
+    @classmethod
+    def _setup_once_tables(cls):
+        if cls.run_define_tables == "once":
+            cls.define_tables(cls._tables_metadata)
+            if cls.run_create_tables == "once":
+                cls._tables_metadata.create_all(cls.bind)
+            cls.tables.update(cls._tables_metadata.tables)
+            cls.sequences.update(cls._tables_metadata._sequences)
+
+    def _setup_each_tables(self):
+        if self.run_define_tables == "each":
+            self.define_tables(self._tables_metadata)
+            if self.run_create_tables == "each":
+                self._tables_metadata.create_all(self.bind)
+            self.tables.update(self._tables_metadata.tables)
+            self.sequences.update(self._tables_metadata._sequences)
+        elif self.run_create_tables == "each":
+            self._tables_metadata.create_all(self.bind)
+
+    def _setup_each_inserts(self):
+        if self.run_inserts == "each":
+            self._load_fixtures()
+            with self.bind.begin() as conn:
+                self.insert_data(conn)
+
+    def _teardown_each_tables(self):
+        if self.run_define_tables == "each":
+            self.tables.clear()
+            if self.run_create_tables == "each":
+                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+            self._tables_metadata.clear()
+        elif self.run_create_tables == "each":
+            drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+        savepoints = getattr(config.requirements, "savepoints", False)
+        if savepoints:
+            savepoints = savepoints.enabled
+
+        # no need to run deletes if tables are recreated on setup
+        if (
+            self.run_define_tables != "each"
+            and self.run_create_tables != "each"
+            and self.run_deletes == "each"
+        ):
+            with self.bind.begin() as conn:
+                for table in reversed(
+                    [
+                        t
+                        for (t, fks) in sort_tables_and_constraints(
+                            self._tables_metadata.tables.values()
+                        )
+                        if t is not None
+                    ]
+                ):
+                    try:
+                        if savepoints:
+                            with conn.begin_nested():
+                                conn.execute(table.delete())
+                        else:
+                            conn.execute(table.delete())
+                    except sa.exc.DBAPIError as ex:
+                        print(
+                            ("Error emptying table %s: %r" % (table, ex)),
+                            file=sys.stderr,
+                        )
+
+    @classmethod
+    def _teardown_once_metadata_bind(cls):
+        if cls.run_create_tables:
+            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+        if cls.run_dispose_bind == "once":
+            cls.dispose_bind(cls.bind)
+
+        cls._tables_metadata.bind = None
+
+        if cls.run_setup_bind is not None:
+            cls.bind = None
+
+    @classmethod
+    def setup_bind(cls):
+        return config.db
+
+    @classmethod
+    def dispose_bind(cls, bind):
+        if hasattr(bind, "dispose"):
+            bind.dispose()
+        elif hasattr(bind, "close"):
+            bind.close()
+
+    @classmethod
+    def define_tables(cls, metadata):
+        pass
+
+    @classmethod
+    def fixtures(cls):
+        return {}
+
+    @classmethod
+    def insert_data(cls, connection):
+        pass
+
+    def sql_count_(self, count, fn):
+        self.assert_sql_count(self.bind, fn, count)
+
+    def sql_eq_(self, callable_, statements):
+        self.assert_sql(self.bind, callable_, statements)
+
+    @classmethod
+    def _load_fixtures(cls):
+        """Insert rows as represented by the fixtures() method."""
+        headers, rows = {}, {}
+        for table, data in cls.fixtures().items():
+            if len(data) < 2:
+                continue
+            if isinstance(table, str):
+                table = cls.tables[table]
+            headers[table] = data[0]
+            rows[table] = data[1:]
+        for table, fks in sort_tables_and_constraints(
+            cls._tables_metadata.tables.values()
+        ):
+            if table is None:
+                continue
+            if table not in headers:
+                continue
+            with cls.bind.begin() as conn:
+                conn.execute(
+                    table.insert(),
+                    [
+                        dict(zip(headers[table], column_values))
+                        for column_values in rows[table]
+                    ],
+                )
+
+
+class NoCache:
+    @config.fixture(autouse=True, scope="function")
+    def _disable_cache(self):
+        _cache = config.db._compiled_cache
+        config.db._compiled_cache = None
+        yield
+        config.db._compiled_cache = _cache
+
+
+class RemovesEvents:
+    @util.memoized_property
+    def _event_fns(self):
+        return set()
+
+    def event_listen(self, target, name, fn, **kw):
+        self._event_fns.add((target, name, fn))
+        event.listen(target, name, fn, **kw)
+
+    @config.fixture(autouse=True, scope="function")
+    def _remove_events(self):
+        yield
+        for key in self._event_fns:
+            event.remove(*key)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+    run_inserts = run_deletes = None
+
+    __backend__ = True
+    __requires__ = ("computed_columns", "table_reflection")
+
+    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
+    def normalize(self, text):
+        return self.regexp.sub("", text).lower()
+
+    @classmethod
+    def define_tables(cls, metadata):
+        from ... import Integer
+        from ... import testing
+        from ...schema import Column
+        from ...schema import Computed
+        from ...schema import Table
+
+        Table(
+            "computed_default_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("normal", Integer),
+            Column("computed_col", Integer, Computed("normal + 42")),
+            Column("with_default", Integer, server_default="42"),
+        )
+
+        t = Table(
+            "computed_column_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("normal", Integer),
+            Column("computed_no_flag", Integer, Computed("normal + 42")),
+        )
+
+        if testing.requires.schemas.enabled:
+            t2 = Table(
+                "computed_column_table",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("normal", Integer),
+                Column("computed_no_flag", Integer, Computed("normal / 42")),
+                schema=config.test_schema,
+            )
+
+        if testing.requires.computed_columns_virtual.enabled:
+            t.append_column(
+                Column(
+                    "computed_virtual",
+                    Integer,
+                    Computed("normal + 2", persisted=False),
+                )
+            )
+            if testing.requires.schemas.enabled:
+                t2.append_column(
+                    Column(
+                        "computed_virtual",
+                        Integer,
+                        Computed("normal / 2", persisted=False),
+                    )
+                )
+        if testing.requires.computed_columns_stored.enabled:
+            t.append_column(
+                Column(
+                    "computed_stored",
+                    Integer,
+                    Computed("normal - 42", persisted=True),
+                )
+            )
+            if testing.requires.schemas.enabled:
+                t2.append_column(
+                    Column(
+                        "computed_stored",
+                        Integer,
+                        Computed("normal * 42", persisted=True),
+                    )
+                )
+
+
+class CacheKeyFixture:
+    def _compare_equal(self, a, b, compare_values):
+        a_key = a._generate_cache_key()
+        b_key = b._generate_cache_key()
+
+        if a_key is None:
+            assert a._annotations.get("nocache")
+
+            assert b_key is None
+        else:
+            eq_(a_key.key, b_key.key)
+            eq_(hash(a_key.key), hash(b_key.key))
+
+            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+                assert a_param.compare(b_param, compare_values=compare_values)
+        return a_key, b_key
+
+    def _run_cache_key_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            if a == b:
+                a_key, b_key = self._compare_equal(
+                    case_a[a], case_b[b], compare_values
+                )
+                if a_key is None:
+                    continue
+            else:
+                a_key = case_a[a]._generate_cache_key()
+                b_key = case_b[b]._generate_cache_key()
+
+                if a_key is None or b_key is None:
+                    if a_key is None:
+                        assert case_a[a]._annotations.get("nocache")
+                    if b_key is None:
+                        assert case_b[b]._annotations.get("nocache")
+                    continue
+
+                if a_key.key == b_key.key:
+                    for a_param, b_param in zip(
+                        a_key.bindparams, b_key.bindparams
+                    ):
+                        if not a_param.compare(
+                            b_param, compare_values=compare_values
+                        ):
+                            break
+                    else:
+                        # this fails unconditionally since we could not
+                        # find bound parameter values that differed.
+                        # Usually we intended to get two distinct keys here
+                        # so the failure will be more descriptive using the
+                        # ne_() assertion.
+                        ne_(a_key.key, b_key.key)
+                else:
+                    ne_(a_key.key, b_key.key)
+
+            # ClauseElement-specific test to ensure the cache key
+            # collected all the bound parameters that aren't marked
+            # as "literal execute"
+            if isinstance(case_a[a], ClauseElement) and isinstance(
+                case_b[b], ClauseElement
+            ):
+                assert_a_params = []
+                assert_b_params = []
+
+                for elem in visitors.iterate(case_a[a]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_a_params.append(elem)
+
+                for elem in visitors.iterate(case_b[b]):
+                    if elem.__visit_name__ == "bindparam":
+                        assert_b_params.append(elem)
+
+                # note we're asserting the order of the params as well as
+                # if there are dupes or not.  ordering has to be
+                # deterministic and matches what a traversal would provide.
+                eq_(
+                    sorted(a_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_a_params), key=lambda b: b.key
+                    ),
+                )
+                eq_(
+                    sorted(b_key.bindparams, key=lambda b: b.key),
+                    sorted(
+                        util.unique_list(assert_b_params), key=lambda b: b.key
+                    ),
+                )
+
+    def _run_cache_key_equal_fixture(self, fixture, compare_values):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            self._compare_equal(case_a[a], case_b[b], compare_values)
+
+
+def insertmanyvalues_fixture(
+    connection, randomize_rows=False, warn_on_downgraded=False
+):
+    dialect = connection.dialect
+    orig_dialect = dialect._deliver_insertmanyvalues_batches
+    orig_conn = connection._exec_insertmany_context
+
+    class RandomCursor:
+        __slots__ = ("cursor",)
+
+        def __init__(self, cursor):
+            self.cursor = cursor
+
+        # only this method is called by the deliver method.
+        # by not having the other methods we assert that those aren't being
+        # used
+
+        def fetchall(self):
+            rows = self.cursor.fetchall()
+            rows = list(rows)
+            random.shuffle(rows)
+            return rows
+
+    def _deliver_insertmanyvalues_batches(
+        cursor, statement, parameters, generic_setinputsizes, context
+    ):
+        if randomize_rows:
+            cursor = RandomCursor(cursor)
+        for batch in orig_dialect(
+            cursor, statement, parameters, generic_setinputsizes, context
+        ):
+            if warn_on_downgraded and batch.is_downgraded:
+                util.warn("Batches were downgraded for sorted INSERT")
+
+            yield batch
+
+    def _exec_insertmany_context(
+        dialect,
+        context,
+    ):
+        with mock.patch.object(
+            dialect,
+            "_deliver_insertmanyvalues_batches",
+            new=_deliver_insertmanyvalues_batches,
+        ):
+            return orig_conn(dialect, context)
+
+    connection._exec_insertmany_context = _exec_insertmany_context
index b0823983fe31d4f86f9687f0d26b6dfbb9b10d4e..89155a841900cc1c2b8bc8ae51a534648ff510b5 100644 (file)
@@ -13,20 +13,20 @@ unpickling.
 
 from __future__ import annotations
 
-from . import fixtures
+from .entities import ComparableEntity
 from ..schema import Column
 from ..types import String
 
 
-class User(fixtures.ComparableEntity):
+class User(ComparableEntity):
     pass
 
 
-class Order(fixtures.ComparableEntity):
+class Order(ComparableEntity):
     pass
 
 
-class Dingaling(fixtures.ComparableEntity):
+class Dingaling(ComparableEntity):
     pass
 
 
@@ -34,20 +34,20 @@ class EmailUser(User):
     pass
 
 
-class Address(fixtures.ComparableEntity):
+class Address(ComparableEntity):
     pass
 
 
 # TODO: these are kind of arbitrary....
-class Child1(fixtures.ComparableEntity):
+class Child1(ComparableEntity):
     pass
 
 
-class Child2(fixtures.ComparableEntity):
+class Child2(ComparableEntity):
     pass
 
 
-class Parent(fixtures.ComparableEntity):
+class Parent(ComparableEntity):
     pass
 
 
@@ -61,7 +61,7 @@ class Mixin:
     email_address = Column(String)
 
 
-class AddressWMixin(Mixin, fixtures.ComparableEntity):
+class AddressWMixin(Mixin, ComparableEntity):
     pass
 
 
index cff53ea727f580b689544d0792d0c3239b0652e3..393070d08c24d49c55fbfb551f0bfcfbaab96a69 100644 (file)
@@ -10,6 +10,7 @@
 from __future__ import annotations
 
 import abc
+from argparse import Namespace
 import configparser
 import logging
 import os
@@ -51,7 +52,7 @@ file_config = None
 logging = None
 include_tags = set()
 exclude_tags = set()
-options = None
+options: Namespace = None  # type: ignore
 
 
 def setup_options(make_option):
index 17bd038d38e537140499921a941c1cb3f84f5ea8..a676e7e28d04514e85ba55d1195767dddaf6ec7f 100644 (file)
@@ -11,13 +11,15 @@ import operator
 import os
 import re
 import sys
+from typing import TYPE_CHECKING
 import uuid
 
 import pytest
 
 try:
     # installed by bootstrap.py
-    import sqla_plugin_base as plugin_base
+    if not TYPE_CHECKING:
+        import sqla_plugin_base as plugin_base
 except ImportError:
     # assume we're a package, use traditional import
     from . import plugin_base
index ccd06716e0917c191909bde7ac264b2d5d932af0..cf24b43a9693ec9e7b8f305308bb8bd10c03fd55 100644 (file)
@@ -216,22 +216,21 @@ def provide_metadata(fn, *args, **kw):
         # we have to hardcode some of that cleanup ahead of time.
 
         # close ORM sessions
-        fixtures._close_all_sessions()
+        fixtures.close_all_sessions()
 
         # integrate with the "connection" fixture as there are many
         # tests where it is used along with provide_metadata
-        if fixtures._connection_fixture_connection:
+        cfc = fixtures.base._connection_fixture_connection
+        if cfc:
             # TODO: this warning can be used to find all the places
             # this is used with connection fixture
             # warn("mixing legacy provide metadata with connection fixture")
-            drop_all_tables_from_metadata(
-                metadata, fixtures._connection_fixture_connection
-            )
+            drop_all_tables_from_metadata(metadata, cfc)
             # as the provide_metadata fixture is often used with "testing.db",
             # when we do the drop we have to commit the transaction so that
             # the DB is actually updated as the CREATE would have been
             # committed
-            fixtures._connection_fixture_connection.get_transaction().commit()
+            cfc.get_transaction().commit()
         else:
             drop_all_tables_from_metadata(metadata, config.db)
         self.metadata = prev_meta
index 4857a19f1494008af5931e607d1d33fa5825cf4e..ed5a4a92e5ab667318ac2e3c8ef7014d51dcdc8a 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -113,7 +113,7 @@ application-import-names = sqlalchemy,test
 per-file-ignores =
     **/__init__.py:F401
     test/*:FA100
-    test/ext/mypy/plain_files/*:F821,E501,FA100
+    test/typing/plain_files/*:F821,E501,FA100
     test/ext/mypy/plugin_files/*:F821,E501,FA100
     lib/sqlalchemy/events.py:F401
     lib/sqlalchemy/schema.py:F401
index 047853e675321446c72a6e3c03ed49f090373907..fc6be0f0960b283ffc33e9889b01f40e45821a27 100644 (file)
@@ -46,6 +46,7 @@ from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import pickleable
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -53,11 +54,11 @@ from sqlalchemy.testing.util import gc_collect
 from ..orm import _fixtures
 
 
-class A(fixtures.ComparableEntity):
+class A(ComparableEntity):
     pass
 
 
-class B(fixtures.ComparableEntity):
+class B(ComparableEntity):
     pass
 
 
@@ -916,7 +917,7 @@ class MemUsageWBackendTest(fixtures.MappedTest, EnsureZeroed):
 
         @profile_memory()
         def go():
-            class A(fixtures.ComparableEntity):
+            class A(ComparableEntity):
                 pass
 
             class B(A):
@@ -997,10 +998,10 @@ class MemUsageWBackendTest(fixtures.MappedTest, EnsureZeroed):
 
         @profile_memory()
         def go():
-            class A(fixtures.ComparableEntity):
+            class A(ComparableEntity):
                 pass
 
-            class B(fixtures.ComparableEntity):
+            class B(ComparableEntity):
                 pass
 
             self.mapper_registry.map_imperatively(
index 1767f2f4e34adf878227eddad77468c5c548eaf7..228489349a15c339cc285ff6836d79254255c9be 100644 (file)
@@ -41,6 +41,7 @@ from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import expect_deprecated
 from sqlalchemy.testing.assertions import is_false
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.provision import normalize_sequence
 from .test_engine_py3k import AsyncFixture as _AsyncFixture
 from ...orm import _fixtures
@@ -1023,7 +1024,7 @@ class AsyncAttrsTest(
     def decl_base(self, metadata):
         _md = metadata
 
-        class Base(fixtures.ComparableEntity, AsyncAttrs, DeclarativeBase):
+        class Base(ComparableEntity, AsyncAttrs, DeclarativeBase):
             metadata = _md
             type_annotation_map = {
                 str: String().with_variant(
index 62e15a124be7ac9b926fe617ffff1d7be57fdd3b..d6d059cbef9657550fe69298fc27319fa6db925d 100644 (file)
@@ -21,6 +21,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally
 from sqlalchemy.testing.schema import Column
@@ -144,7 +145,7 @@ class ConcreteInhTest(
             "punion",
         )
 
-        class Employee(Base, fixtures.ComparableEntity):
+        class Employee(Base, ComparableEntity):
             __table__ = punion
             __mapper_args__ = {"polymorphic_on": punion.c.type}
 
@@ -174,7 +175,7 @@ class ConcreteInhTest(
     def test_concrete_inline_non_polymorphic(self):
         """test the example from the declarative docs."""
 
-        class Employee(Base, fixtures.ComparableEntity):
+        class Employee(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -211,7 +212,7 @@ class ConcreteInhTest(
         self._roundtrip(Employee, Manager, Engineer, Boss, polymorphic=False)
 
     def test_abstract_concrete_base_didnt_configure(self):
-        class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(AbstractConcreteBase, Base, ComparableEntity):
             strict_attrs = True
 
         assert_raises_message(
@@ -269,7 +270,7 @@ class ConcreteInhTest(
         )
 
     def test_abstract_concrete_extension(self):
-        class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(AbstractConcreteBase, Base, ComparableEntity):
             name = Column(String(50))
 
         class Manager(Employee):
@@ -321,7 +322,7 @@ class ConcreteInhTest(
     def test_abstract_concrete_extension_descriptor_refresh(
         self, use_strict_attrs
     ):
-        class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(AbstractConcreteBase, Base, ComparableEntity):
             strict_attrs = use_strict_attrs
 
             @declared_attr
@@ -378,7 +379,7 @@ class ConcreteInhTest(
         eq_(e1.name, "d")
 
     def test_concrete_extension(self):
-        class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(ConcreteBase, Base, ComparableEntity):
             __tablename__ = "employee"
             employee_id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -428,7 +429,7 @@ class ConcreteInhTest(
         self._roundtrip(Employee, Manager, Engineer, Boss)
 
     def test_concrete_extension_warn_for_overlap(self):
-        class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(ConcreteBase, Base, ComparableEntity):
             __tablename__ = "employee"
 
             employee_id = Column(
@@ -463,7 +464,7 @@ class ConcreteInhTest(
             configure_mappers()
 
     def test_concrete_extension_warn_concrete_disc_resolves_overlap(self):
-        class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(ConcreteBase, Base, ComparableEntity):
             _concrete_discriminator_name = "_type"
 
             __tablename__ = "employee"
@@ -562,7 +563,7 @@ class ConcreteInhTest(
         )
 
     def test_abs_concrete_extension_warn_for_overlap(self):
-        class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(AbstractConcreteBase, Base, ComparableEntity):
             name = Column(String(50))
             __mapper_args__ = {
                 "polymorphic_identity": "employee",
@@ -595,7 +596,7 @@ class ConcreteInhTest(
     def test_abs_concrete_extension_warn_concrete_disc_resolves_overlap(
         self, use_strict_attrs
     ):
-        class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(AbstractConcreteBase, Base, ComparableEntity):
             strict_attrs = use_strict_attrs
             _concrete_discriminator_name = "_type"
 
@@ -671,7 +672,7 @@ class ConcreteInhTest(
         assert PolyTest.__mapper__.polymorphic_on is Test.__table__.c.type
 
     def test_ok_to_override_type_from_abstract(self):
-        class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class Employee(AbstractConcreteBase, Base, ComparableEntity):
             name = Column(String(50))
 
         class Manager(Employee):
@@ -734,7 +735,7 @@ class ConcreteExtensionConfigTest(
     __dialect__ = "default"
 
     def test_classreg_setup(self):
-        class A(Base, fixtures.ComparableEntity):
+        class A(Base, ComparableEntity):
             __tablename__ = "a"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -744,7 +745,7 @@ class ConcreteExtensionConfigTest(
                 "BC", primaryjoin="BC.a_id == A.id", collection_class=set
             )
 
-        class BC(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+        class BC(AbstractConcreteBase, Base, ComparableEntity):
             a_id = Column(Integer, ForeignKey("a.id"))
 
         class B(BC):
index 103d3d07ffd5a6d5bc66898975706cbb7b8364be..4f81d7c470605a0f9a489a80f53b279d1e50ea6b 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -74,11 +75,11 @@ class DeferredReflectPKFKTest(DeferredReflectBase):
         )
 
     def test_pk_fk(self):
-        class B(DeferredReflection, fixtures.ComparableEntity, Base):
+        class B(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "b"
             a = relationship("A")
 
-        class A(DeferredReflection, fixtures.ComparableEntity, Base):
+        class A(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "a"
 
         DeferredReflection.prepare(testing.db)
@@ -133,11 +134,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
             eq_(a1.user, User(name="u1"))
 
     def test_exception_prepare_not_called(self):
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
             addresses = relationship("Address", backref="user")
 
-        class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Address(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "addresses"
 
         assert_raises_message(
@@ -152,11 +153,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
 
     @testing.variation("bind", ["engine", "connection", "raise_"])
     def test_basic_deferred(self, bind):
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
             addresses = relationship("Address", backref="user")
 
-        class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Address(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "addresses"
 
         if bind.engine:
@@ -218,11 +219,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
         class OtherDefBase(DeferredReflection, Base):
             __abstract__ = True
 
-        class User(fixtures.ComparableEntity, DefBase):
+        class User(ComparableEntity, DefBase):
             __tablename__ = "users"
             addresses = relationship("Address", backref="user")
 
-        class Address(fixtures.ComparableEntity, DefBase):
+        class Address(ComparableEntity, DefBase):
             __tablename__ = "addresses"
 
         class Fake(OtherDefBase):
@@ -232,11 +233,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
         self._roundtrip()
 
     def test_redefine_fk_double(self):
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
             addresses = relationship("Address", backref="user")
 
-        class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Address(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "addresses"
             user_id = Column(Integer, ForeignKey("users.id"))
 
@@ -247,7 +248,7 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
         """test that __mapper_args__ is not called until *after*
         table reflection"""
 
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
 
             @declared_attr
@@ -277,10 +278,10 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
 
     @testing.requires.predictable_gc
     def test_cls_not_strong_ref(self):
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
 
-        class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Address(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "addresses"
 
         eq_(len(_DeferredMapperConfig._configs), 2)
@@ -340,26 +341,26 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase):
             )
 
     def test_string_resolution(self):
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
 
             items = relationship("Item", secondary="user_items")
 
-        class Item(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Item(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "items"
 
         DeferredReflection.prepare(testing.db)
         self._roundtrip()
 
     def test_table_resolution(self):
-        class User(DeferredReflection, fixtures.ComparableEntity, Base):
+        class User(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "users"
 
             items = relationship(
                 "Item", secondary=Table("user_items", Base.metadata)
             )
 
-        class Item(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Item(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "items"
 
         DeferredReflection.prepare(testing.db)
@@ -408,7 +409,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         )
 
     def test_basic(self, decl_base):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+        class Foo(DeferredReflection, ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -422,7 +423,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_add_subclass_column(self, decl_base):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+        class Foo(DeferredReflection, ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -437,7 +438,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_add_subclass_mapped_column(self, decl_base):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+        class Foo(DeferredReflection, ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -452,7 +453,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_subclass_mapped_column_no_existing(self, decl_base):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+        class Foo(DeferredReflection, ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -469,7 +470,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
                 bar_data: Mapped[str] = mapped_column(use_existing_column=True)
 
     def test_add_pk_column(self, decl_base):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+        class Foo(DeferredReflection, ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -484,7 +485,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_add_pk_mapped_column(self, decl_base):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+        class Foo(DeferredReflection, ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -521,7 +522,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         )
 
     def test_basic(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Foo(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -536,7 +537,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_add_subclass_column(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Foo(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -552,7 +553,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_add_pk_column(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Foo(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -568,7 +569,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
         self._roundtrip()
 
     def test_add_fk_pk_column(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+        class Foo(DeferredReflection, ComparableEntity, Base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
diff --git a/test/ext/mypy/plain_files/core_ddl.py b/test/ext/mypy/plain_files/core_ddl.py
deleted file mode 100644 (file)
index 673a90e..0000000
+++ /dev/null
@@ -1,43 +0,0 @@
-from sqlalchemy import CheckConstraint
-from sqlalchemy import Column
-from sqlalchemy import DateTime
-from sqlalchemy import ForeignKey
-from sqlalchemy import Index
-from sqlalchemy import Integer
-from sqlalchemy import MetaData
-from sqlalchemy import PrimaryKeyConstraint
-from sqlalchemy import String
-from sqlalchemy import Table
-
-
-m = MetaData()
-
-
-t1 = Table(
-    "t1",
-    m,
-    Column("id", Integer, primary_key=True),
-    Column("data", String),
-    Column("data2", String(50)),
-    Column("timestamp", DateTime()),
-    Index(None, "data2"),
-)
-
-t2 = Table(
-    "t2",
-    m,
-    Column("t1id", ForeignKey("t1.id")),
-    Column("q", Integer, CheckConstraint("q > 5")),
-)
-
-t3 = Table(
-    "t3",
-    m,
-    Column("x", Integer),
-    Column("y", Integer),
-    Column("t1id", ForeignKey(t1.c.id)),
-    PrimaryKeyConstraint("x", "y"),
-)
-
-# cols w/ no name or type, used by declarative
-c1: Column[int] = Column(ForeignKey(t3.c.x))
index a9cc1eb336351c67781392fe9221d829f47e4d0c..f1b36ac52bb836393ec0d78d800651a63970bfb9 100644 (file)
@@ -1,35 +1,11 @@
 import os
-import re
 import shutil
-import sys
-import tempfile
-from typing import Any
-from typing import cast
-from typing import List
-from typing import Tuple
 
 from sqlalchemy import testing
-from sqlalchemy import util
-from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 
 
-def _file_combinations(dirname):
-    path = os.path.join(os.path.dirname(__file__), dirname)
-    files = []
-    for f in os.listdir(path):
-        if f.endswith(".py"):
-            files.append(os.path.join(os.path.dirname(__file__), dirname, f))
-
-    for extra_dir in testing.config.options.mypy_extra_test_paths:
-        if extra_dir and os.path.isdir(extra_dir):
-            for f in os.listdir(os.path.join(extra_dir, dirname)):
-                if f.endswith(".py"):
-                    files.append(os.path.join(extra_dir, dirname, f))
-    return files
-
-
 def _incremental_dirs():
     path = os.path.join(os.path.dirname(__file__), "incremental")
     files = []
@@ -47,99 +23,9 @@ def _incremental_dirs():
     return files
 
 
-@testing.add_to_marker.mypy
-class MypyPluginTest(fixtures.TestBase):
-    __tags__ = ("mypy",)
-    __requires__ = ("no_sqlalchemy2_stubs",)
-
-    @testing.fixture(scope="function")
-    def per_func_cachedir(self):
-        yield from self._cachedir()
-
-    @testing.fixture(scope="class")
-    def cachedir(self):
-        yield from self._cachedir()
-
-    def _cachedir(self):
-        # as of mypy 0.971 i think we need to keep mypy_path empty
-        mypy_path = ""
-
-        with tempfile.TemporaryDirectory() as cachedir:
-            with open(
-                os.path.join(cachedir, "sqla_mypy_config.cfg"), "w"
-            ) as config_file:
-                config_file.write(
-                    f"""
-                    [mypy]\n
-                    plugins = sqlalchemy.ext.mypy.plugin\n
-                    show_error_codes = True\n
-                    {mypy_path}
-                    disable_error_code = no-untyped-call
-
-                    [mypy-sqlalchemy.*]
-                    ignore_errors = True
-
-                    """
-                )
-            with open(
-                os.path.join(cachedir, "plain_mypy_config.cfg"), "w"
-            ) as config_file:
-                config_file.write(
-                    f"""
-                    [mypy]\n
-                    show_error_codes = True\n
-                    {mypy_path}
-                    disable_error_code = var-annotated,no-untyped-call
-                    [mypy-sqlalchemy.*]
-                    ignore_errors = True
-
-                    """
-                )
-            yield cachedir
-
-    @testing.fixture()
-    def mypy_runner(self, cachedir):
-        from mypy import api
-
-        def run(path, use_plugin=True, incremental=False):
-            args = [
-                "--strict",
-                "--raise-exceptions",
-                "--cache-dir",
-                cachedir,
-                "--config-file",
-                os.path.join(
-                    cachedir,
-                    "sqla_mypy_config.cfg"
-                    if use_plugin
-                    else "plain_mypy_config.cfg",
-                ),
-            ]
-
-            # mypy as of 0.990 is more aggressively blocking messaging
-            # for paths that are in sys.path, and as pytest puts currdir,
-            # test/ etc in sys.path, just copy the source file to the
-            # tempdir we are working in so that we don't have to try to
-            # manipulate sys.path and/or guess what mypy is doing
-            filename = os.path.basename(path)
-            test_program = os.path.join(cachedir, filename)
-            shutil.copyfile(path, test_program)
-            args.append(test_program)
-
-            # I set this locally but for the suite here needs to be
-            # disabled
-            os.environ.pop("MYPY_FORCE_COLOR", None)
-
-            result = api.run(args)
-            return result
-
-        return run
-
+class MypyPluginTest(fixtures.MypyTest):
     @testing.combinations(
-        *[
-            (pathname, testing.exclusions.closed())
-            for pathname in _incremental_dirs()
-        ],
+        *[(pathname) for pathname in _incremental_dirs()],
         argnames="pathname",
     )
     @testing.requires.patch_library
@@ -175,7 +61,7 @@ class MypyPluginTest(fixtures.TestBase):
             result = mypy_runner(
                 dest,
                 use_plugin=True,
-                incremental=True,
+                use_cachedir=cachedir,
             )
             eq_(
                 result[2],
@@ -186,191 +72,11 @@ class MypyPluginTest(fixtures.TestBase):
 
     @testing.combinations(
         *(
-            cast(
-                List[Tuple[Any, ...]],
-                [
-                    ("w_plugin", os.path.basename(path), path, True)
-                    for path in _file_combinations("plugin_files")
-                ],
-            )
-            + cast(
-                List[Tuple[Any, ...]],
-                [
-                    ("plain", os.path.basename(path), path, False)
-                    for path in _file_combinations("plain_files")
-                ],
-            )
+            (os.path.basename(path), path, True)
+            for path in fixtures.MypyTest.file_combinations("plugin_files")
         ),
-        argnames="filename,path,use_plugin",
-        id_="isaa",
+        argnames="path",
+        id_="ia",
     )
-    def test_files(self, mypy_runner, filename, path, use_plugin):
-        expected_messages = []
-        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
-        py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
-
-        from sqlalchemy.ext.mypy.util import mypy_14
-
-        with open(path) as file_:
-            current_assert_messages = []
-            for num, line in enumerate(file_, 1):
-                m = py_ver_re.match(line)
-                if m:
-                    major, _, minor = m.group(1).partition(".")
-                    if sys.version_info < (int(major), int(minor)):
-                        config.skip_test(
-                            "Requires python >= %s" % (m.group(1))
-                        )
-                    continue
-                if line.startswith("# NOPLUGINS"):
-                    use_plugin = False
-                    continue
-
-                m = expected_re.match(line)
-                if m:
-                    is_mypy = bool(m.group(1))
-                    is_re = bool(m.group(2))
-                    is_type = bool(m.group(3))
-
-                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
-
-                    if is_type:
-                        if not is_re:
-                            # the goal here is that we can cut-and-paste
-                            # from vscode -> pylance into the
-                            # EXPECTED_TYPE: line, then the test suite will
-                            # validate that line against what mypy produces
-                            expected_msg = re.sub(
-                                r"([\[\]])",
-                                lambda m: rf"\{m.group(0)}",
-                                expected_msg,
-                            )
-
-                            # note making sure preceding text matches
-                            # with a dot, so that an expect for "Select"
-                            # does not match "TypedSelect"
-                            expected_msg = re.sub(
-                                r"([\w_]+)",
-                                lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
-                                expected_msg,
-                            )
-
-                            expected_msg = re.sub(
-                                "List", "builtins.list", expected_msg
-                            )
-
-                            expected_msg = re.sub(
-                                r"\b(int|str|float|bool)\b",
-                                lambda m: rf"builtins.{m.group(0)}\*?",
-                                expected_msg,
-                            )
-                            # expected_msg = re.sub(
-                            #     r"(Sequence|Tuple|List|Union)",
-                            #     lambda m: fr"typing.{m.group(0)}\*?",
-                            #     expected_msg,
-                            # )
-
-                        is_mypy = is_re = True
-                        expected_msg = f'Revealed type is "{expected_msg}"'
-
-                    if mypy_14 and util.py39:
-                        # use_lowercase_names, py39 and above
-                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363  # noqa: E501
-
-                        # skip first character which could be capitalized
-                        # "List item x not found" type of message
-                        expected_msg = expected_msg[0] + re.sub(
-                            r"\b(List|Tuple|Dict|Set)\b"
-                            if is_type
-                            else r"\b(List|Tuple|Dict|Set|Type)\b",
-                            lambda m: m.group(1).lower(),
-                            expected_msg[1:],
-                        )
-
-                    if mypy_14 and util.py310:
-                        # use_or_syntax, py310 and above
-                        # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368  # noqa: E501
-                        expected_msg = re.sub(
-                            r"Optional\[(.*?)\]",
-                            lambda m: f"{m.group(1)} | None",
-                            expected_msg,
-                        )
-                    current_assert_messages.append(
-                        (is_mypy, is_re, expected_msg.strip())
-                    )
-                elif current_assert_messages:
-                    expected_messages.extend(
-                        (num, is_mypy, is_re, expected_msg)
-                        for (
-                            is_mypy,
-                            is_re,
-                            expected_msg,
-                        ) in current_assert_messages
-                    )
-                    current_assert_messages[:] = []
-
-        result = mypy_runner(path, use_plugin=use_plugin)
-
-        not_located = []
-
-        if expected_messages:
-            # mypy 0.990 changed how return codes work, so don't assume a
-            # 1 or a 0 return code here, could be either depending on if
-            # errors were generated or not
-
-            output = []
-
-            raw_lines = result[0].split("\n")
-            while raw_lines:
-                e = raw_lines.pop(0)
-                if re.match(r".+\.py:\d+: error: .*", e):
-                    output.append(("error", e))
-                elif re.match(
-                    r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
-                ):
-                    while raw_lines:
-                        ol = raw_lines.pop(0)
-                        if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
-                            break
-                elif re.match(
-                    r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
-                ):
-                    pass
-                elif re.match(r".+\.py:\d+: note: .*", e):
-                    output.append(("note", e))
-
-            for num, is_mypy, is_re, msg in expected_messages:
-                msg = msg.replace("'", '"')
-                prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
-                for idx, (typ, errmsg) in enumerate(output):
-                    if is_re:
-                        if re.match(
-                            rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",  # noqa: E501
-                            errmsg,
-                        ):
-                            break
-                    elif (
-                        f"{filename}:{num}: {typ}: {prefix}{msg}"
-                        in errmsg.replace("'", '"')
-                    ):
-                        break
-                else:
-                    not_located.append(msg)
-                    continue
-                del output[idx]
-
-            if not_located:
-                print(f"Couldn't locate expected messages: {not_located}")
-                print("\n".join(msg for _, msg in output))
-                assert False, "expected messages not found, see stdout"
-
-            if output:
-                print(f"{len(output)} messages from mypy were not consumed:")
-                print("\n".join(msg for _, msg in output))
-                assert False, "errors and/or notes remain, see stdout"
-
-        else:
-            if result[2] != 0:
-                print(result[0])
-
-            eq_(result[2], 0, msg=result)
+    def test_plugin_files(self, mypy_typecheck_file, path):
+        mypy_typecheck_file(path, use_plugin=True)
index d7b7b0bb20e223577451bb46cdabfbc1642eb804..abf8efec7129d28fb03f4e580852674758574735 100644 (file)
@@ -45,6 +45,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.entities import ComparableEntity  # noqa
 from sqlalchemy.testing.entities import ComparableMixin  # noqa
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -2451,7 +2452,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest):
                 creator=lambda point: PointData(point=point),
             )
 
-        class PointData(fixtures.ComparableEntity, cls.DeclarativeBasic):
+        class PointData(ComparableEntity, cls.DeclarativeBasic):
             __tablename__ = "point"
 
             id = Column(
index e68f9c0351da93015a7e15d905139d2fdabd68a6..4421c3a6edef5e4a8489d01b3505910c071a3206 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing import not_in
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.schema import Column
 
 
@@ -176,7 +177,7 @@ class IndexPropertyArrayTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class Array(fixtures.ComparableEntity, Base):
+        class Array(ComparableEntity, Base):
             __tablename__ = "array"
 
             id = Column(
@@ -270,7 +271,7 @@ class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest):
                 expr = super().expr(model)
                 return expr.astext.cast(self.cast_type)
 
-        class Json(fixtures.ComparableEntity, Base):
+        class Json(ComparableEntity, Base):
             __tablename__ = "json"
 
             id = Column(
index 6c428fa85467e97addbb8b0d0df990f074194719..dffdac8d84298fceb2e3fed90b682f4f5baa4aa1 100644 (file)
@@ -35,6 +35,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import BasicEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -44,7 +45,7 @@ from sqlalchemy.types import TypeDecorator
 from sqlalchemy.types import VARCHAR
 
 
-class Foo(fixtures.BasicEntity):
+class Foo(BasicEntity):
     pass
 
 
@@ -52,7 +53,7 @@ class SubFoo(Foo):
     pass
 
 
-class Foo2(fixtures.BasicEntity):
+class Foo2(BasicEntity):
     pass
 
 
@@ -68,7 +69,7 @@ class FooWithEq:
         return self.id == other.id
 
 
-class FooWNoHash(fixtures.BasicEntity):
+class FooWNoHash(BasicEntity):
     __hash__ = None
 
 
index 8318484c027e3e99c2a55f188eadd3391a5ae844..a52c59e2d34d23c3b88eedd75f48f77452bd5019 100644 (file)
@@ -20,6 +20,7 @@ from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
@@ -29,11 +30,11 @@ def pickle_protocols():
     # return iter([-1, 0, 1, 2])
 
 
-class User(fixtures.ComparableEntity):
+class User(ComparableEntity):
     pass
 
 
-class Address(fixtures.ComparableEntity):
+class Address(ComparableEntity):
     pass
 
 
index 985b600f0d66938be5c1edc94e34933db0298ce2..7085b2af9f6b5a8fb3393520a89c7b361364047e 100644 (file)
@@ -59,6 +59,7 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -517,7 +518,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
             xyzzy = "magic"
 
         # _as_declarative() inspects obj.__class__.__bases__
-        class User(BrokenParent, fixtures.ComparableEntity):
+        class User(BrokenParent, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -707,7 +708,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
         assert Base().foobar() == "foobar"
 
     def test_as_declarative(self, metadata):
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -715,7 +716,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
             name = Column("name", String(50))
             addresses = relationship("Address", backref="user")
 
-        class Address(fixtures.ComparableEntity):
+        class Address(ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -746,7 +747,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
             )
 
     def test_map_declaratively(self, metadata):
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -754,7 +755,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
             name = Column("name", String(50))
             addresses = relationship("Address", backref="user")
 
-        class Address(fixtures.ComparableEntity):
+        class Address(ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -1105,7 +1106,7 @@ class DeclarativeMultiBaseTest(
             testing.config.skip_test("current base has no metaclass")
 
     def test_basic(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
 
             id = Column(
@@ -1114,7 +1115,7 @@ class DeclarativeMultiBaseTest(
             name = Column("name", String(50))
             addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
 
             id = Column(
@@ -1263,7 +1264,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_unicode_string_resolve(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
 
             id = Column(
@@ -1272,7 +1273,7 @@ class DeclarativeMultiBaseTest(
             name = Column("name", String(50))
             addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
 
             id = Column(
@@ -1286,7 +1287,7 @@ class DeclarativeMultiBaseTest(
         assert User.addresses.property.mapper.class_ is Address
 
     def test_unicode_string_resolve_backref(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
 
             id = Column(
@@ -1294,7 +1295,7 @@ class DeclarativeMultiBaseTest(
             )
             name = Column("name", String(50))
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
 
             id = Column(
@@ -1729,7 +1730,7 @@ class DeclarativeMultiBaseTest(
         assert User.__mapper__.registry._new_mappers is False
 
     def test_string_dependency_resolution(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1747,7 +1748,7 @@ class DeclarativeMultiBaseTest(
                 ),
             )
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1780,7 +1781,7 @@ class DeclarativeMultiBaseTest(
             ),
         )
 
-        class Foo(Base, fixtures.ComparableEntity):
+        class Foo(Base, ComparableEntity):
             __tablename__ = "foo"
             id = Column(Integer, primary_key=True)
             rel = relationship("User", primaryjoin="User.addresses==Foo.id")
@@ -1792,7 +1793,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_string_dependency_resolution_synonym(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1807,7 +1808,7 @@ class DeclarativeMultiBaseTest(
         sess.expunge_all()
         eq_(sess.query(User).filter(User.name == "ed").one(), User(name="ed"))
 
-        class Foo(Base, fixtures.ComparableEntity):
+        class Foo(Base, ComparableEntity):
             __tablename__ = "foo"
             id = Column(Integer, primary_key=True)
             _user_id = Column(Integer)
@@ -1902,14 +1903,14 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_string_dependency_resolution_no_table(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
             )
             name = Column(String(50))
 
-        class Bar(Base, fixtures.ComparableEntity):
+        class Bar(Base, ComparableEntity):
             __tablename__ = "bar"
             id = Column(Integer, primary_key=True)
             rel = relationship("User", primaryjoin="User.id==Bar.__table__.id")
@@ -1921,14 +1922,14 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_string_w_pj_annotations(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
             )
             name = Column(String(50))
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1947,7 +1948,7 @@ class DeclarativeMultiBaseTest(
     def test_string_dependency_resolution_no_magic(self):
         """test that full tinkery expressions work as written"""
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(Integer, primary_key=True)
             addresses = relationship(
@@ -1955,7 +1956,7 @@ class DeclarativeMultiBaseTest(
                 primaryjoin="User.id==Address.user_id.prop.columns[0]",
             )
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(Integer, primary_key=True)
             user_id = Column(Integer, ForeignKey("users.id"))
@@ -1967,7 +1968,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_string_dependency_resolution_module_qualified(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(Integer, primary_key=True)
             addresses = relationship(
@@ -1976,7 +1977,7 @@ class DeclarativeMultiBaseTest(
                 % (__name__, __name__),
             )
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(Integer, primary_key=True)
             user_id = Column(Integer, ForeignKey("users.id"))
@@ -1988,7 +1989,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_string_dependency_resolution_in_backref(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -1998,7 +1999,7 @@ class DeclarativeMultiBaseTest(
                 backref="user",
             )
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(Integer, primary_key=True)
             email = Column(String(50))
@@ -2011,7 +2012,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_string_dependency_resolution_tables(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -2023,7 +2024,7 @@ class DeclarativeMultiBaseTest(
                 backref="users",
             )
 
-        class Prop(Base, fixtures.ComparableEntity):
+        class Prop(Base, ComparableEntity):
             __tablename__ = "props"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -2042,7 +2043,7 @@ class DeclarativeMultiBaseTest(
 
     def test_string_dependency_resolution_table_over_class(self):
         # test for second half of #5774
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -2052,7 +2053,7 @@ class DeclarativeMultiBaseTest(
                 backref="users",
             )
 
-        class Prop(Base, fixtures.ComparableEntity):
+        class Prop(Base, ComparableEntity):
             __tablename__ = "props"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -2071,7 +2072,7 @@ class DeclarativeMultiBaseTest(
 
     def test_string_dependency_resolution_class_over_table(self):
         # test for second half of #5774
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -2091,7 +2092,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_uncompiled_attributes_in_relationship(self):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -2099,7 +2100,7 @@ class DeclarativeMultiBaseTest(
             email = Column(String(50))
             user_id = Column(Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -2276,14 +2277,14 @@ class DeclarativeMultiBaseTest(
     def test_add_prop_auto(
         self, require_metaclass, assert_user_address_mapping, _column
     ):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column("id", Integer, primary_key=True)
 
         User.name = _column("name", String(50))
         User.addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = _column(Integer, primary_key=True)
 
@@ -2300,7 +2301,7 @@ class DeclarativeMultiBaseTest(
 
     @testing.combinations(Column, mapped_column, argnames="_column")
     def test_add_prop_manual(self, assert_user_address_mapping, _column):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = _column("id", Integer, primary_key=True)
 
@@ -2309,7 +2310,7 @@ class DeclarativeMultiBaseTest(
             User, "addresses", relationship("Address", backref="user")
         )
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = _column(Integer, primary_key=True)
 
@@ -2404,7 +2405,7 @@ class DeclarativeMultiBaseTest(
         A(brap=B())
 
     def test_eager_order_by(self):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2412,7 +2413,7 @@ class DeclarativeMultiBaseTest(
             email = Column("email", String(50))
             user_id = Column("user_id", Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2439,7 +2440,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_order_by_multi(self):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2447,7 +2448,7 @@ class DeclarativeMultiBaseTest(
             email = Column("email", String(50))
             user_id = Column("user_id", Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2473,7 +2474,7 @@ class DeclarativeMultiBaseTest(
             "Ignoring declarative-like tuple value of " "attribute 'name'"
         ):
 
-            class User(Base, fixtures.ComparableEntity):
+            class User(Base, ComparableEntity):
                 __tablename__ = "users"
                 id = Column("id", Integer, primary_key=True)
                 name = (Column("name", String(50)),)
@@ -2573,7 +2574,7 @@ class DeclarativeMultiBaseTest(
         is_(inspect(Employee).local_table, Person.__table__)
 
     def test_expression(self, require_metaclass):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2581,7 +2582,7 @@ class DeclarativeMultiBaseTest(
             name = Column("name", String(50))
             addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2614,7 +2615,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_useless_declared_attr(self):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2622,7 +2623,7 @@ class DeclarativeMultiBaseTest(
             email = Column("email", String(50))
             user_id = Column("user_id", Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2696,7 +2697,7 @@ class DeclarativeMultiBaseTest(
                     return Column(Integer)
 
     def test_column(self, require_metaclass):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2736,7 +2737,7 @@ class DeclarativeMultiBaseTest(
         eq_(Foo.d.impl.active_history, False)
 
     def test_column_properties(self):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -2744,7 +2745,7 @@ class DeclarativeMultiBaseTest(
             email = Column(String(50))
             user_id = Column(Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2778,13 +2779,13 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_column_properties_2(self):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(Integer, primary_key=True)
             email = Column(String(50))
             user_id = Column(Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column("id", Integer, primary_key=True)
             name = Column("name", String(50))
@@ -2798,7 +2799,7 @@ class DeclarativeMultiBaseTest(
         eq_(set(Address.__table__.c.keys()), {"id", "email", "user_id"})
 
     def test_deferred(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -2819,7 +2820,7 @@ class DeclarativeMultiBaseTest(
         self.assert_sql_count(testing.db, go, 1)
 
     def test_composite_inline(self):
-        class AddressComposite(fixtures.ComparableEntity):
+        class AddressComposite(ComparableEntity):
             def __init__(self, street, state):
                 self.street = street
                 self.state = state
@@ -2827,7 +2828,7 @@ class DeclarativeMultiBaseTest(
             def __composite_values__(self):
                 return [self.street, self.state]
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "user"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -2848,7 +2849,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_composite_separate(self):
-        class AddressComposite(fixtures.ComparableEntity):
+        class AddressComposite(ComparableEntity):
             def __init__(self, street, state):
                 self.street = street
                 self.state = state
@@ -2856,7 +2857,7 @@ class DeclarativeMultiBaseTest(
             def __composite_values__(self):
                 return [self.street, self.state]
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "user"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -2903,7 +2904,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_synonym_inline(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2939,7 +2940,7 @@ class DeclarativeMultiBaseTest(
             def __eq__(self, other):
                 return self.__clause_element__() == other + " FOO"
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2955,7 +2956,7 @@ class DeclarativeMultiBaseTest(
         eq_(sess.query(User).filter(User.name == "someuser").one(), u1)
 
     def test_synonym_added(self, require_metaclass):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2982,7 +2983,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_reentrant_compile_via_foreignkey(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -2990,7 +2991,7 @@ class DeclarativeMultiBaseTest(
             name = Column("name", String(50))
             addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -3026,7 +3027,7 @@ class DeclarativeMultiBaseTest(
         )
 
     def test_relationship_reference(self, require_metaclass):
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -3034,7 +3035,7 @@ class DeclarativeMultiBaseTest(
             email = Column("email", String(50))
             user_id = Column("user_id", Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -3105,7 +3106,7 @@ class DeclarativeMultiBaseTest(
         eq_(sess.execute(t1.select()).fetchall(), [("someid", "somedata")])
 
     def test_synonym_for(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
index 79639ed9ce0cea4661ef23b334c2f88e531df280..c5b908cd822ead8aac0e8a014e3aa7699f0eaf1e 100644 (file)
@@ -28,6 +28,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -147,7 +148,7 @@ class DeclarativeInheritanceTest(
         configure_mappers()
 
     def test_joined(self):
-        class Company(Base, fixtures.ComparableEntity):
+        class Company(Base, ComparableEntity):
             __tablename__ = "companies"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -155,7 +156,7 @@ class DeclarativeInheritanceTest(
             name = Column("name", String(50))
             employees = relationship("Person")
 
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -245,7 +246,7 @@ class DeclarativeInheritanceTest(
         self.assert_sql_count(testing.db, go, 1)
 
     def test_add_subcol_after_the_fact(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -274,7 +275,7 @@ class DeclarativeInheritanceTest(
         )
 
     def test_add_parentcol_after_the_fact(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -303,7 +304,7 @@ class DeclarativeInheritanceTest(
         )
 
     def test_add_sub_parentcol_after_the_fact(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -340,7 +341,7 @@ class DeclarativeInheritanceTest(
         )
 
     def test_subclass_mixin(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column("id", Integer, primary_key=True)
             name = Column("name", String(50))
@@ -532,7 +533,7 @@ class DeclarativeInheritanceTest(
         """test single inheritance where all the columns are on the base
         class."""
 
-        class Company(Base, fixtures.ComparableEntity):
+        class Company(Base, ComparableEntity):
             __tablename__ = "companies"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -540,7 +541,7 @@ class DeclarativeInheritanceTest(
             name = Column("name", String(50))
             employees = relationship("Person")
 
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -604,7 +605,7 @@ class DeclarativeInheritanceTest(
 
         """
 
-        class Company(Base, fixtures.ComparableEntity):
+        class Company(Base, ComparableEntity):
             __tablename__ = "companies"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -612,7 +613,7 @@ class DeclarativeInheritanceTest(
             name = Column("name", String(50))
             employees = relationship("Person")
 
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -787,7 +788,7 @@ class DeclarativeInheritanceTest(
     def test_single_constraint_on_sub(self):
         """test the somewhat unusual case of [ticket:3341]"""
 
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1147,7 +1148,7 @@ class DeclarativeInheritanceTest(
         is_(Manager.id.property.columns[0], Person.__table__.c.id)
 
     def test_joined_from_single(self):
-        class Company(Base, fixtures.ComparableEntity):
+        class Company(Base, ComparableEntity):
             __tablename__ = "companies"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -1155,7 +1156,7 @@ class DeclarativeInheritanceTest(
             name = Column("name", String(50))
             employees = relationship("Person")
 
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1220,7 +1221,7 @@ class DeclarativeInheritanceTest(
         )
 
     def test_single_from_joined_colsonsub(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1268,7 +1269,7 @@ class DeclarativeInheritanceTest(
         is_(B.__mapper__.polymorphic_on, A.__table__.c.discriminator)
 
     def test_add_deferred(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -1292,7 +1293,7 @@ class DeclarativeInheritanceTest(
 
         """
 
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1306,7 +1307,7 @@ class DeclarativeInheritanceTest(
             primary_language_id = Column(Integer, ForeignKey("languages.id"))
             primary_language = relationship("Language")
 
-        class Language(Base, fixtures.ComparableEntity):
+        class Language(Base, ComparableEntity):
             __tablename__ = "languages"
             id = Column(
                 Integer, primary_key=True, test_needs_autoincrement=True
@@ -1354,7 +1355,7 @@ class DeclarativeInheritanceTest(
         )
 
     def test_single_three_levels(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column(Integer, primary_key=True)
             name = Column(String(50))
@@ -1415,7 +1416,7 @@ class DeclarativeInheritanceTest(
         assert_raises(sa.exc.ArgumentError, go)
 
     def test_single_no_special_cols(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column("id", Integer, primary_key=True)
             name = Column("name", String(50))
@@ -1431,7 +1432,7 @@ class DeclarativeInheritanceTest(
         assert_raises_message(sa.exc.ArgumentError, "place primary key", go)
 
     def test_single_no_table_args(self):
-        class Person(Base, fixtures.ComparableEntity):
+        class Person(Base, ComparableEntity):
             __tablename__ = "people"
             id = Column("id", Integer, primary_key=True)
             name = Column("name", String(50))
index a2ed8f0ebfc4e59cf039a9164a8ab0958fcc803c..be9d28b1931eb8bcce78eb23c48974987e135c57 100644 (file)
@@ -9,6 +9,8 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
+from sqlalchemy.testing.entities import ComparableMixin
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -62,12 +64,12 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
         )
 
     def test_basic(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             __autoload_with__ = testing.db
             addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             __autoload_with__ = testing.db
 
@@ -92,13 +94,13 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
         eq_(a1.user, User(name="u1"))
 
     def test_rekey_wbase(self):
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             __autoload_with__ = testing.db
             nom = Column("name", String(50), key="nom")
             addresses = relationship("Address", backref="user")
 
-        class Address(Base, fixtures.ComparableEntity):
+        class Address(Base, ComparableEntity):
             __tablename__ = "addresses"
             __autoload_with__ = testing.db
 
@@ -125,14 +127,14 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
 
     def test_rekey_wdecorator(self):
         @registry.mapped
-        class User(fixtures.ComparableMixin):
+        class User(ComparableMixin):
             __tablename__ = "users"
             __autoload_with__ = testing.db
             nom = Column("name", String(50), key="nom")
             addresses = relationship("Address", backref="user")
 
         @registry.mapped
-        class Address(fixtures.ComparableMixin):
+        class Address(ComparableMixin):
             __tablename__ = "addresses"
             __autoload_with__ = testing.db
 
@@ -158,12 +160,12 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase):
         assert_raises(TypeError, User, name="u3")
 
     def test_supplied_fk(self):
-        class IMHandle(Base, fixtures.ComparableEntity):
+        class IMHandle(Base, ComparableEntity):
             __tablename__ = "imhandles"
             __autoload_with__ = testing.db
             user_id = Column("user_id", Integer, ForeignKey("users.id"))
 
-        class User(Base, fixtures.ComparableEntity):
+        class User(Base, ComparableEntity):
             __tablename__ = "users"
             __autoload_with__ = testing.db
             handles = relationship("IMHandle", backref="user")
index 9550671119638101d24210d002357dfa2a1cd726..2888aeaf9e1a69ee12ae9a477e085d455891df8f 100644 (file)
@@ -61,7 +61,7 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
         style: testing.Variation,
         sort_by_parameter_order,
     ):
-        class A(fixtures.ComparableEntity, decl_base):
+        class A(ComparableEntity, decl_base):
             __tablename__ = "a"
             id: Mapped[int] = mapped_column(Identity(), primary_key=True)
             data: Mapped[str]
@@ -1700,7 +1700,7 @@ class BulkDMLReturningJoinedInhTest(
     def setup_classes(cls):
         decl_base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, decl_base):
+        class A(ComparableEntity, decl_base):
             __tablename__ = "a"
             id: Mapped[int] = mapped_column(Identity(), primary_key=True)
             type: Mapped[str]
@@ -1814,7 +1814,7 @@ class BulkDMLReturningSingleInhTest(
     def setup_classes(cls):
         decl_base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, decl_base):
+        class A(ComparableEntity, decl_base):
             __tablename__ = "a"
             id: Mapped[int] = mapped_column(Identity(), primary_key=True)
             type: Mapped[str]
@@ -1857,7 +1857,7 @@ class BulkDMLReturningConcreteInhTest(
     def setup_classes(cls):
         decl_base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, decl_base):
+        class A(ComparableEntity, decl_base):
             __tablename__ = "a"
             id: Mapped[int] = mapped_column(Identity(), primary_key=True)
             type: Mapped[str]
@@ -1897,7 +1897,7 @@ class CTETest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         decl_base = cls.DeclarativeBasic
 
-        class User(fixtures.ComparableEntity, decl_base):
+        class User(ComparableEntity, decl_base):
             __tablename__ = "users"
             id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
             username: Mapped[str]
index ae64bc9f920f8b01ae8038f374d886bd2788651a..5b5989c9205dae306afcb82200b59a405ec0f377 100644 (file)
@@ -9,15 +9,16 @@ from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import config
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
 
-class Company(fixtures.ComparableEntity):
+class Company(ComparableEntity):
     pass
 
 
-class Person(fixtures.ComparableEntity):
+class Person(ComparableEntity):
     pass
 
 
@@ -33,19 +34,19 @@ class Boss(Manager):
     pass
 
 
-class Machine(fixtures.ComparableEntity):
+class Machine(ComparableEntity):
     pass
 
 
-class MachineType(fixtures.ComparableEntity):
+class MachineType(ComparableEntity):
     pass
 
 
-class Paperwork(fixtures.ComparableEntity):
+class Paperwork(ComparableEntity):
     pass
 
 
-class Page(fixtures.ComparableEntity):
+class Page(ComparableEntity):
     pass
 
 
@@ -568,7 +569,7 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest):
                     items["__mapper_args__"][mapper_opt] = value[mapper_opt]
 
             if is_base:
-                klass = type(key, (fixtures.ComparableEntity, base), items)
+                klass = type(key, (ComparableEntity, base), items)
             else:
                 klass = type(key, (base,), items)
 
index 3ec9b55857710d2749a0849a886b8dd8949c6af1..f0967d86cc6db132ed29e21d6849cf1ced634fb6 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -37,7 +38,7 @@ class ABCTest(fixtures.MappedTest):
 
     @testing.combinations(("union",), ("none",))
     def test_abc_poly_roundtrip(self, fetchtype):
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
         class B(A):
index 60322295283ccc24ac1ac5a23e71db99ca374963..aa076e19f90d255bbd5b371396614069a886e348 100644 (file)
@@ -40,7 +40,7 @@ from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import ComparableEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.provision import normalize_sequence
 from sqlalchemy.testing.schema import Column
@@ -1083,7 +1083,7 @@ class RelationshipTest8(fixtures.MappedTest):
         )
 
     def test_selfref_onjoined(self):
-        class Taggable(fixtures.ComparableEntity):
+        class Taggable(ComparableEntity):
             pass
 
         class User(Taggable):
@@ -1880,14 +1880,14 @@ class InheritingEagerTest(fixtures.MappedTest):
         """test that Query uses the full set of mapper._eager_loaders
         when generating SQL"""
 
-        class Person(fixtures.ComparableEntity):
+        class Person(ComparableEntity):
             pass
 
         class Employee(Person):
             def __init__(self, name="bob"):
                 self.name = name
 
-        class Tag(fixtures.ComparableEntity):
+        class Tag(ComparableEntity):
             def __init__(self, label):
                 self.label = label
 
index ab97c5f250bda624d6afabb0a7fff18a6a2674c4..769ef645e8773b94324e69c4ec3cb875f3956d49 100644 (file)
@@ -46,6 +46,8 @@ from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import Conditional
 from sqlalchemy.testing.assertsql import Or
 from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -188,7 +190,7 @@ class PolyExpressionEagerLoad(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
 
             id = Column(
@@ -782,7 +784,7 @@ class PolymorphicSynonymTest(fixtures.MappedTest):
         )
 
     def test_polymorphic_synonym(self):
-        class T1(fixtures.ComparableEntity):
+        class T1(ComparableEntity):
             def info(self):
                 return "THE INFO IS:" + self._info
 
@@ -1055,16 +1057,16 @@ class CascadeTest(fixtures.MappedTest):
         )
 
     def test_cascade(self):
-        class T1(fixtures.BasicEntity):
+        class T1(BasicEntity):
             pass
 
-        class T2(fixtures.BasicEntity):
+        class T2(BasicEntity):
             pass
 
         class T3(T2):
             pass
 
-        class T4(fixtures.BasicEntity):
+        class T4(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1137,7 +1139,7 @@ class M2OUseGetTest(fixtures.MappedTest):
         )
 
         # test [ticket:1186]
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class Sub(Base):
@@ -1405,7 +1407,7 @@ class EagerTargetingTest(fixtures.MappedTest):
     def test_adapt_stringency(self):
         b_table, a_table = self.tables.b_table, self.tables.a_table
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
         class B(A):
@@ -2107,7 +2109,7 @@ class VersioningTest(fixtures.MappedTest):
             self.tables.stuff,
         )
 
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class Sub(Base):
@@ -2171,7 +2173,7 @@ class VersioningTest(fixtures.MappedTest):
     def test_delete(self):
         subtable, base = self.tables.subtable, self.tables.base
 
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class Sub(Base):
@@ -2833,10 +2835,10 @@ class OptimizedLoadTest(fixtures.MappedTest):
     def test_no_optimize_on_map_to_join(self):
         base, sub = self.tables.base, self.tables.sub
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
-        class JoinBase(fixtures.ComparableEntity):
+        class JoinBase(ComparableEntity):
             pass
 
         class SubJoinBase(JoinBase):
@@ -2902,7 +2904,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
 
         base, sub = self.tables.base, self.tables.sub
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
         class Sub(Base):
@@ -3014,7 +3016,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
 
         base, sub = self.tables.base, self.tables.sub
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
         class Sub(Base):
@@ -3050,7 +3052,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
     def test_column_expression(self):
         base, sub = self.tables.base, self.tables.sub
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
         class Sub(Base):
@@ -3079,7 +3081,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
     def test_column_expression_joined(self):
         base, sub = self.tables.base, self.tables.sub
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
         class Sub(Base):
@@ -3120,7 +3122,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
     def test_composite_column_joined(self):
         base, with_comp = self.tables.base, self.tables.with_comp
 
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class WithComp(Base):
@@ -3168,7 +3170,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
             expected_eager_defaults and testing.db.dialect.insert_returning
         )
 
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class Sub(Base):
@@ -3257,7 +3259,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
     def test_dont_generate_on_none(self):
         base, sub = self.tables.base, self.tables.sub
 
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class Sub(Base):
@@ -3305,7 +3307,7 @@ class OptimizedLoadTest(fixtures.MappedTest):
             self.tables.subsub,
         )
 
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class Sub(Base):
@@ -3820,13 +3822,13 @@ class DeleteOrphanTest(fixtures.MappedTest):
         )
 
     def test_orphan_message(self):
-        class Base(fixtures.BasicEntity):
+        class Base(BasicEntity):
             pass
 
         class SubClass(Base):
             pass
 
-        class Parent(fixtures.BasicEntity):
+        class Parent(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3927,11 +3929,11 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class Parent(fixtures.ComparableEntity, Base):
+        class Parent(ComparableEntity, Base):
             __tablename__ = "parent"
             id = Column(Integer, primary_key=True)
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             parent_id = Column(ForeignKey("parent.id"))
@@ -4019,7 +4021,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class AJoined(fixtures.ComparableEntity, Base):
+        class AJoined(ComparableEntity, Base):
             __tablename__ = "ajoined"
             id = Column(Integer, primary_key=True)
             type = Column(String(10), nullable=False)
@@ -4038,7 +4040,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
             id = Column(ForeignKey("ajoined.id"), primary_key=True)
             __mapper_args__ = {"polymorphic_identity": "subb"}
 
-        class ASingle(fixtures.ComparableEntity, Base):
+        class ASingle(ComparableEntity, Base):
             __tablename__ = "asingle"
             id = Column(Integer, primary_key=True)
             type = Column(String(10), nullable=False)
@@ -4110,7 +4112,7 @@ class CompositeJoinedInTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "table_a"
 
             order_id: Mapped[str] = mapped_column(String(50), primary_key=True)
index 755af492ca29c6b99a5087d9c00cbf6c5396e454..87908482886512fc8bf35593d8920753d11792d9 100644 (file)
@@ -829,11 +829,11 @@ class LoaderOptionsTest(
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class Parent(fixtures.ComparableEntity, Base):
+        class Parent(ComparableEntity, Base):
             __tablename__ = "parent"
             id = Column(Integer, primary_key=True)
 
-        class Child(fixtures.ComparableEntity, Base):
+        class Child(ComparableEntity, Base):
             __tablename__ = "child"
             id = Column(Integer, primary_key=True)
             parent_id = Column(Integer, ForeignKey("parent.id"))
@@ -850,7 +850,7 @@ class LoaderOptionsTest(
                 "polymorphic_load": "selectin",
             }
 
-        class Other(fixtures.ComparableEntity, Base):
+        class Other(ComparableEntity, Base):
             __tablename__ = "other"
 
             id = Column(Integer, primary_key=True)
index f244c911060525114fea01ba6f4a13e163a2dced..0a92b7f5a40bf45448a4e5d8e6b3b984bd9fbd93 100644 (file)
@@ -13,11 +13,12 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 
 
-class Person(fixtures.ComparableEntity):
+class Person(ComparableEntity):
     pass
 
 
@@ -33,7 +34,7 @@ class Boss(Manager):
     pass
 
 
-class Company(fixtures.ComparableEntity):
+class Company(ComparableEntity):
     pass
 
 
index 293c7dfb59b1daeb9e3d6be4a144adb1eca65324..4ed2a453d3ec0cc1b2b4ca412359628984a3ddd6 100644 (file)
@@ -29,11 +29,11 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
 
-class Company(fixtures.ComparableEntity):
+class Company(ComparableEntity):
     pass
 
 
-class Person(fixtures.ComparableEntity):
+class Person(ComparableEntity):
     pass
 
 
@@ -49,11 +49,11 @@ class Boss(Manager):
     pass
 
 
-class Machine(fixtures.ComparableEntity):
+class Machine(ComparableEntity):
     pass
 
 
-class Paperwork(fixtures.ComparableEntity):
+class Paperwork(ComparableEntity):
     pass
 
 
index 47827e8887db491604d4a2923834da7eecc7e580..5fb15c9b7c507aeba321c744d52869e726dc0ffe 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import String
 from sqlalchemy.orm import Session
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -31,7 +32,7 @@ class InheritingSelectablesTest(fixtures.MappedTest):
         connection.execute(foo.insert(), dict(a="i am bar", b="bar"))
         connection.execute(foo.insert(), dict(a="also bar", b="bar"))
 
-        class Foo(fixtures.ComparableEntity):
+        class Foo(ComparableEntity):
             pass
 
         class Bar(Foo):
index 4461ac86d28bf2881f8ae12bcb7d56bfbd44b889..52f3cf9c9f71996f8405a921fe8f65c45b914c23 100644 (file)
@@ -34,6 +34,7 @@ from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -1693,7 +1694,7 @@ class SingleOnJoinedTest(fixtures.MappedTest):
         )
 
     def test_single_on_joined(self):
-        class Person(fixtures.ComparableEntity):
+        class Person(ComparableEntity):
             pass
 
         class Employee(Person):
index f53aedf07d69bd24d6c1a6d9a5e7e46ddb97b7d4..603e71d249ddde62b9011ec992cce4200c9b9864 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.fixtures import ComparableEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 
 
index 58e1ab97b9d230136a5d698bafbc8f83aa1ed401..4b9d3b2e0255087f747c47dee7591e05660eaa5e 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import not_in
 from sqlalchemy.testing.assertions import assert_warns
+from sqlalchemy.testing.entities import BasicEntity
 from sqlalchemy.testing.util import all_partial_orderings
 from sqlalchemy.testing.util import gc_collect
 
@@ -770,10 +771,10 @@ class AttributesTest(fixtures.ORMTest):
     def test_lazyhistory(self):
         """tests that history functions work with lazy-loading attributes"""
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         instrumentation.register_class(Foo)
@@ -1737,7 +1738,7 @@ class PendingBackrefTest(fixtures.ORMTest):
 
 class HistoryTest(fixtures.TestBase):
     def _fixture(self, uselist, useobject, active_history, **kw):
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
         instrumentation.register_class(Foo)
@@ -1752,10 +1753,10 @@ class HistoryTest(fixtures.TestBase):
         return Foo
 
     def _two_obj_fixture(self, uselist, active_history=False):
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             def __bool__(self):
                 assert False
 
@@ -2571,10 +2572,10 @@ class HistoryTest(fixtures.TestBase):
     def test_dict_collections(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         instrumentation.register_class(Foo)
@@ -2630,10 +2631,10 @@ class HistoryTest(fixtures.TestBase):
     def test_object_collections_mutate(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         instrumentation.register_class(Foo)
@@ -2818,10 +2819,10 @@ class HistoryTest(fixtures.TestBase):
     def test_collections_via_backref(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         instrumentation.register_class(Foo)
@@ -2890,10 +2891,10 @@ class LazyloadHistoryTest(fixtures.TestBase):
     def test_lazy_backref_collections(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         lazy_load = []
@@ -2949,10 +2950,10 @@ class LazyloadHistoryTest(fixtures.TestBase):
     def test_collections_via_lazyload(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         lazy_load = []
@@ -3012,7 +3013,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
     def test_scalar_via_lazyload(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
         lazy_load = None
@@ -3068,7 +3069,7 @@ class LazyloadHistoryTest(fixtures.TestBase):
     def test_scalar_via_lazyload_with_active(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
         lazy_load = None
@@ -3129,10 +3130,10 @@ class LazyloadHistoryTest(fixtures.TestBase):
     def test_scalar_object_via_lazyload(self):
         # TODO: break into individual tests
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         lazy_load = None
@@ -3195,10 +3196,10 @@ class LazyloadHistoryTest(fixtures.TestBase):
 class CollectionKeyTest(fixtures.ORMTest):
     @testing.fixture
     def dict_collection(self):
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             def __init__(self, name):
                 self.name = name
 
@@ -3222,10 +3223,10 @@ class CollectionKeyTest(fixtures.ORMTest):
 
     @testing.fixture
     def list_collection(self):
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
-        class Bar(fixtures.BasicEntity):
+        class Bar(BasicEntity):
             pass
 
         instrumentation.register_class(Foo)
index 3c49fc8dcd0f17d35e077ec55b13060f65acfe92..6b84ec6f7887d3fb0c6871638a57478282db5e28 100644 (file)
@@ -30,6 +30,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import not_in
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -3256,13 +3257,13 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest):
             self.tables.accounts,
         )
 
-        class Customer(fixtures.ComparableEntity):
+        class Customer(ComparableEntity):
             pass
 
-        class Account(fixtures.ComparableEntity):
+        class Account(ComparableEntity):
             pass
 
-        class SalesRep(fixtures.ComparableEntity):
+        class SalesRep(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3428,13 +3429,13 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest):
             self.tables.addresses,
         )
 
-        class Address(fixtures.ComparableEntity):
+        class Address(ComparableEntity):
             pass
 
-        class Home(fixtures.ComparableEntity):
+        class Home(ComparableEntity):
             pass
 
-        class Business(fixtures.ComparableEntity):
+        class Business(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Address, addresses)
@@ -3488,13 +3489,13 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest):
             self.tables.addresses,
         )
 
-        class Address(fixtures.ComparableEntity):
+        class Address(ComparableEntity):
             pass
 
-        class Home(fixtures.ComparableEntity):
+        class Home(ComparableEntity):
             pass
 
-        class Business(fixtures.ComparableEntity):
+        class Business(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Address, addresses)
@@ -3546,10 +3547,10 @@ class CollectionAssignmentOrphanTest(fixtures.MappedTest):
     def test_basic(self):
         table_b, table_a = self.tables.table_b, self.tables.table_a
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -4044,10 +4045,10 @@ class PartialFlushTest(fixtures.MappedTest):
     def test_o2m_m2o(self):
         base, noninh_child = self.tables.base, self.tables.noninh_child
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
-        class Child(fixtures.ComparableEntity):
+        class Child(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -4103,7 +4104,7 @@ class PartialFlushTest(fixtures.MappedTest):
             self.tables.parent,
         )
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
         class Parent(Base):
index 562d9b9dc95daa022d6d72c76c2e6d77065415df..d230d4aafc0658af77ff826f8d0071d9f5888ced 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.testing.assertsql import assert_engine
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import Conditional
 from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.entities import BasicEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -209,7 +210,7 @@ class ExcludedDefaultsTest(fixtures.MappedTest):
     def test_exclude(self):
         dt = self.tables.dt
 
-        class Foo(fixtures.BasicEntity):
+        class Foo(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
index c93ac6d60ac0a1b457adab04e9370aa772f01b9a..fa044d033c99edd60293181e35d4a041ef43d830 100644 (file)
@@ -43,6 +43,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -2117,7 +2118,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)
@@ -2127,7 +2128,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
             bs = relationship("B", order_by="B.id")
 
-        class A_default(fixtures.ComparableEntity, Base):
+        class A_default(ComparableEntity, Base):
             __tablename__ = "a_default"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)
@@ -2135,7 +2136,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
             my_expr = query_expression(default_expr=literal(15))
 
-        class B(fixtures.ComparableEntity, Base):
+        class B(ComparableEntity, Base):
             __tablename__ = "b"
             id = Column(Integer, primary_key=True)
             a_id = Column(ForeignKey("a.id"))
@@ -2144,7 +2145,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
 
             b_expr = query_expression()
 
-        class C(fixtures.ComparableEntity, Base):
+        class C(ComparableEntity, Base):
             __tablename__ = "c"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)
@@ -2489,7 +2490,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)
index 0fa6c94a30d22cbe9eff3b92a931b9ba362bd53f..23248349cd2ccc066bc94504c7e78d75bd708df8 100644 (file)
@@ -58,6 +58,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import CacheKeyFixture
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally
@@ -841,7 +842,7 @@ class DeprecatedMapperTest(
 
         assert_col = []
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             def _get_name(self):
                 assert_col.append(("get", self._name))
                 return self._name
index fa44dbf10dcd6da0871dca8bba7663c9df57ac4c..261269dec1a07a7d77e15ca51a5f6887b4b08f1a 100644 (file)
@@ -41,6 +41,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -4318,10 +4319,10 @@ class OrderBySecondaryTest(fixtures.MappedTest):
     def test_ordering(self):
         a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b)
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -4361,7 +4362,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest):
     def test_basic(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -4437,7 +4438,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest):
     def test_lazy_fallback_doesnt_affect_eager(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -4484,7 +4485,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest):
     def test_with_deferred(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -4545,7 +4546,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest):
     def test_options(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -4620,7 +4621,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest):
     def test_no_depth(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -4813,7 +4814,7 @@ class SelfReferentialM2MEagerTest(fixtures.MappedTest):
     def test_basic(self):
         widget, widget_rel = self.tables.widget, self.tables.widget_rel
 
-        class Widget(fixtures.ComparableEntity):
+        class Widget(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -5236,12 +5237,12 @@ class SubqueryTest(fixtures.MappedTest):
             self.tables.users_table,
         )
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @property
             def prop_score(self):
                 return sum([tag.prop_score for tag in self.tags])
 
-        class Tag(fixtures.ComparableEntity):
+        class Tag(ComparableEntity):
             @property
             def prop_score(self):
                 return self.score1 * self.score2
@@ -5395,10 +5396,10 @@ class CorrelatedSubqueryTest(fixtures.MappedTest):
     def _do_test(self, labeled, ondate, aliasstuff):
         stuff, users = self.tables.stuff, self.tables.users
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             pass
 
-        class Stuff(fixtures.ComparableEntity):
+        class Stuff(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Stuff, stuff)
index 6b28a637ad33e0ec87599db08e998ffe8bd777ad..51c86a5f1dad90dbc35e5603b5d147e46d46ffa0 100644 (file)
@@ -3912,13 +3912,13 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest):
             self.tables.sub1,
         )
 
-        class Base(fixtures.ComparableEntity):
+        class Base(ComparableEntity):
             pass
 
-        class Sub1(fixtures.ComparableEntity):
+        class Sub1(ComparableEntity):
             pass
 
-        class Sub2(fixtures.ComparableEntity):
+        class Sub2(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
index e3936159adaa2d854cc7cdb8bebd261afe6c97c5..4ab9617123c708898239011ed54b08be5444de4a 100644 (file)
@@ -32,6 +32,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -1025,10 +1026,10 @@ class GetterStateTest(_fixtures.FixtureTest):
             Column("data", MyHashType()),
         )
 
-        class Category(fixtures.ComparableEntity):
+        class Category(ComparableEntity):
             pass
 
-        class Article(fixtures.ComparableEntity):
+        class Article(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Category, category)
@@ -1314,10 +1315,10 @@ class CorrelatedTest(fixtures.MappedTest):
     def test_correlated_lazyload(self):
         stuff, user_t = self.tables.stuff, self.tables.user_t
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             pass
 
-        class Stuff(fixtures.ComparableEntity):
+        class Stuff(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Stuff, stuff)
index 19caf04487aa21e15b91a39f7fefe2186d730769..a3aad69f0871a2e4b3b2fb5f7ca71c78154e006d 100644 (file)
@@ -51,8 +51,8 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
-from sqlalchemy.testing.fixtures import ComparableEntity
-from sqlalchemy.testing.fixtures import ComparableMixin
+from sqlalchemy.testing.entities import ComparableEntity
+from sqlalchemy.testing.entities import ComparableMixin
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -973,7 +973,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
 
         assert_col = []
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             def _get_name(self):
                 assert_col.append(("get", self._name))
                 return self._name
index 6b3a7c1d6d2897ba1a648dc7f6a983f4f0538343..0c8e2651cdb33d6aadea7b3fbf358fa388bf4460 100644 (file)
@@ -31,6 +31,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import not_in
 from sqlalchemy.testing.assertsql import CountStatements
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -2181,10 +2182,10 @@ class LoadOnPendingTest(fixtures.MappedTest):
 
     @classmethod
     def setup_classes(cls):
-        class Rock(cls.Basic, fixtures.ComparableEntity):
+        class Rock(cls.Basic, ComparableEntity):
             pass
 
-        class Bug(cls.Basic, fixtures.ComparableEntity):
+        class Bug(cls.Basic, ComparableEntity):
             pass
 
     def _setup_delete_orphan_o2o(self):
@@ -2251,7 +2252,7 @@ class PolymorphicOnTest(fixtures.MappedTest):
 
     @classmethod
     def setup_classes(cls):
-        class Employee(cls.Basic, fixtures.ComparableEntity):
+        class Employee(cls.Basic, ComparableEntity):
             pass
 
         class Manager(Employee):
index 367f854427a49df6eb8d75ed31ce0c25257f1e2e..ce5c64a43af6b52a98b6a89cb67364e006cfdd97 100644 (file)
@@ -79,6 +79,7 @@ from sqlalchemy.testing.assertions import expect_raises
 from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.assertions import is_not_none
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -1321,7 +1322,7 @@ class GetTest(QueryTest):
 
         s = users.outerjoin(addresses)
 
-        class UserThing(fixtures.ComparableEntity):
+        class UserThing(ComparableEntity):
             pass
 
         registry.map_imperatively(
index 12651fe36432c169ce3a81322bc6d44d44d79c23..2de35a9a1e8a56085a7e565aea4e90e2deaad94f 100644 (file)
@@ -42,6 +42,8 @@ from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.assertsql import assert_engine
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -334,10 +336,10 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest):
     def _fixture(self, uselist=False):
         a, b = self.tables.a, self.tables.b
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1720,7 +1722,7 @@ class FKsAsPksTest(fixtures.MappedTest):
         )
         tableC.create(connection)
 
-        class C(fixtures.BasicEntity):
+        class C(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1958,10 +1960,10 @@ class RelationshipToSelectableTest(fixtures.MappedTest):
     def test_basic(self):
         items = self.tables.items
 
-        class Container(fixtures.BasicEntity):
+        class Container(BasicEntity):
             pass
 
-        class LineItem(fixtures.BasicEntity):
+        class LineItem(BasicEntity):
             pass
 
         container_select = (
@@ -2050,10 +2052,10 @@ class FKEquatedToConstantTest(fixtures.MappedTest):
     def test_basic(self):
         tag_foo, tags = self.tables.tag_foo, self.tables.tags
 
-        class Tag(fixtures.ComparableEntity):
+        class Tag(ComparableEntity):
             pass
 
-        class TagInstance(fixtures.ComparableEntity):
+        class TagInstance(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2682,13 +2684,13 @@ class TypeMatchTest(fixtures.MappedTest):
     def test_o2m_oncascade(self):
         a, c, b = (self.tables.a, self.tables.c, self.tables.b)
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
-        class C(fixtures.BasicEntity):
+        class C(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2716,13 +2718,13 @@ class TypeMatchTest(fixtures.MappedTest):
     def test_o2m_onflush(self):
         a, c, b = (self.tables.a, self.tables.c, self.tables.b)
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
-        class C(fixtures.BasicEntity):
+        class C(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2747,10 +2749,10 @@ class TypeMatchTest(fixtures.MappedTest):
     def test_o2m_nopoly_onflush(self):
         a, c, b = (self.tables.a, self.tables.c, self.tables.b)
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
         class C(B):
@@ -2778,13 +2780,13 @@ class TypeMatchTest(fixtures.MappedTest):
     def test_m2o_nopoly_onflush(self):
         a, b, d = (self.tables.a, self.tables.b, self.tables.d)
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
         class B(A):
             pass
 
-        class D(fixtures.BasicEntity):
+        class D(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(A, a)
@@ -2805,13 +2807,13 @@ class TypeMatchTest(fixtures.MappedTest):
     def test_m2o_oncascade(self):
         a, b, d = (self.tables.a, self.tables.b, self.tables.d)
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
-        class D(fixtures.BasicEntity):
+        class D(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(A, a)
@@ -2865,10 +2867,10 @@ class TypedAssociationTable(fixtures.MappedTest):
 
         t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1)
 
-        class T1(fixtures.BasicEntity):
+        class T1(BasicEntity):
             pass
 
-        class T2(fixtures.BasicEntity):
+        class T2(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(T2, t2)
@@ -2928,10 +2930,10 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         )
 
     def test_join_on_custom_op_legacy_is_comparison(self):
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2955,10 +2957,10 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         )
 
     def test_join_on_custom_bool_op(self):
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3016,10 +3018,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest):
         return s
 
     def test_o2m_viewonly_oneside(self):
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3049,10 +3051,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest):
         assert b1 not in sess.dirty
 
     def test_m2o_viewonly_oneside(self):
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3082,10 +3084,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest):
         assert b1 not in sess.dirty
 
     def test_o2m_viewonly_only(self):
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3103,10 +3105,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest):
         self._assert_fk(a1, b1, False)
 
     def test_m2o_viewonly_only(self):
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(A, self.tables.t1)
@@ -3151,10 +3153,10 @@ class ViewOnlyM2MBackrefTest(fixtures.MappedTest):
     def test_viewonly(self):
         t1t2, t2, t1 = (self.tables.t1t2, self.tables.t2, self.tables.t1)
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3226,13 +3228,13 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest):
 
         t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1)
 
-        class C1(fixtures.BasicEntity):
+        class C1(BasicEntity):
             pass
 
-        class C2(fixtures.BasicEntity):
+        class C2(BasicEntity):
             pass
 
-        class C3(fixtures.BasicEntity):
+        class C3(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3360,10 +3362,10 @@ class ViewOnlySyncBackref(fixtures.MappedTest):
     @testing.combinations(True, False, None, argnames="B_a_sync")
     @testing.combinations(True, False, argnames="B_a_view")
     def test_case(self, B_a_view, B_a_sync, A_bs_view, A_bs_sync):
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         case = self.cases[(B_a_view, B_a_sync, A_bs_view, A_bs_sync)]
@@ -3490,13 +3492,13 @@ class ViewOnlyUniqueNames(fixtures.MappedTest):
 
         t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1)
 
-        class C1(fixtures.BasicEntity):
+        class C1(BasicEntity):
             pass
 
-        class C2(fixtures.BasicEntity):
+        class C2(BasicEntity):
             pass
 
-        class C3(fixtures.BasicEntity):
+        class C3(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3596,10 +3598,10 @@ class ViewOnlyNonEquijoin(fixtures.MappedTest):
     def test_viewonly_join(self):
         bars, foos = self.tables.bars, self.tables.foos
 
-        class Foo(fixtures.ComparableEntity):
+        class Foo(ComparableEntity):
             pass
 
-        class Bar(fixtures.ComparableEntity):
+        class Bar(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3669,10 +3671,10 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest):
     def test_relationship_on_or(self):
         bars, foos = self.tables.bars, self.tables.foos
 
-        class Foo(fixtures.ComparableEntity):
+        class Foo(ComparableEntity):
             pass
 
-        class Bar(fixtures.ComparableEntity):
+        class Bar(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -3744,10 +3746,10 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest):
     def test_relationship_on_or(self):
         bars, foos = self.tables.bars, self.tables.foos
 
-        class Foo(fixtures.ComparableEntity):
+        class Foo(ComparableEntity):
             pass
 
-        class Bar(fixtures.ComparableEntity):
+        class Bar(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -4006,7 +4008,7 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class Network(fixtures.ComparableEntity, Base):
+        class Network(ComparableEntity, Base):
             __tablename__ = "network"
 
             id = Column(
@@ -4023,7 +4025,7 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest):
                 viewonly=True,
             )
 
-        class Address(fixtures.ComparableEntity, Base):
+        class Address(ComparableEntity, Base):
             __tablename__ = "address"
 
             ip_addr = Column(Integer, primary_key=True)
index 8c6ddfa0e58ce5bb8158c8f0d758ad19ef5ef878..509137afe379d1983cc90b6945fd660045a36d70 100644 (file)
@@ -17,6 +17,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
@@ -49,10 +50,10 @@ class ScopedSessionTest(fixtures.MappedTest):
         class CustomQuery(query.Query):
             pass
 
-        class SomeObject(fixtures.ComparableEntity):
+        class SomeObject(ComparableEntity):
             query = Session.query_property()
 
-        class SomeOtherObject(fixtures.ComparableEntity):
+        class SomeOtherObject(ComparableEntity):
             query = Session.query_property()
             custom_query = Session.query_property(query_cls=CustomQuery)
 
index 2fdc12574f0fd9ee541b0e87c1d1c027f65372c3..c9907c765157e3de3a3ebb1836180c5775ac3679 100644 (file)
@@ -29,7 +29,7 @@ from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import assert_engine
 from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.fixtures import ComparableEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -1712,10 +1712,10 @@ class OrderBySecondaryTest(fixtures.MappedTest):
     def test_ordering(self):
         a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b)
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2239,14 +2239,14 @@ class TupleTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id1 = Column(Integer, primary_key=True)
             id2 = Column(Integer, primary_key=True)
 
             bs = relationship("B", order_by="B.id", back_populates="a")
 
-        class B(fixtures.ComparableEntity, Base):
+        class B(ComparableEntity, Base):
             __tablename__ = "b"
             id = Column(Integer, primary_key=True)
             a_id1 = Column()
@@ -2355,12 +2355,12 @@ class ChunkingTest(fixtures.DeclarativeMappedTest):
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             bs = relationship("B", order_by="B.id", back_populates="a")
 
-        class B(fixtures.ComparableEntity, Base):
+        class B(ComparableEntity, Base):
             __tablename__ = "b"
             id = Column(Integer, primary_key=True)
             a_id = Column(ForeignKey("a.id"))
@@ -2955,7 +2955,7 @@ class SelfRefInheritanceAliasedTest(
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class Foo(fixtures.ComparableEntity, Base):
+        class Foo(ComparableEntity, Base):
             __tablename__ = "foo"
             id = Column(Integer, primary_key=True)
             type = Column(String(50))
@@ -3203,14 +3203,14 @@ class MissingForeignTest(
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             b_id = Column(Integer)
             b = relationship("B", primaryjoin="foreign(A.b_id) == B.id")
             q = Column(Integer)
 
-        class B(fixtures.ComparableEntity, Base):
+        class B(ComparableEntity, Base):
             __tablename__ = "b"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)
@@ -3256,7 +3256,7 @@ class M2OWDegradeTest(
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
-        class A(fixtures.ComparableEntity, Base):
+        class A(ComparableEntity, Base):
             __tablename__ = "a"
             id = Column(Integer, primary_key=True)
             b_id = Column(ForeignKey("b.id"))
@@ -3264,7 +3264,7 @@ class M2OWDegradeTest(
             b_no_omit_join = relationship("B", omit_join=False, overlaps="b")
             q = Column(Integer)
 
-        class B(fixtures.ComparableEntity, Base):
+        class B(ComparableEntity, Base):
             __tablename__ = "b"
             id = Column(Integer, primary_key=True)
             x = Column(Integer)
index 1a83a58be8a1d52ee0fe0c2c3a7cd7d387d3aa3d..00564cfb656fd5808d6ec7e5f1631c53a481955c 100644 (file)
@@ -1766,10 +1766,10 @@ class OrderBySecondaryTest(fixtures.MappedTest):
     def test_ordering(self):
         a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b)
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2440,7 +2440,7 @@ class SelfReferentialTest(fixtures.MappedTest):
     def test_basic(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -2516,7 +2516,7 @@ class SelfReferentialTest(fixtures.MappedTest):
     def test_lazy_fallback_doesnt_affect_eager(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -2562,7 +2562,7 @@ class SelfReferentialTest(fixtures.MappedTest):
     def test_with_deferred(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -2623,7 +2623,7 @@ class SelfReferentialTest(fixtures.MappedTest):
     def test_options(self):
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
@@ -2680,7 +2680,7 @@ class SelfReferentialTest(fixtures.MappedTest):
 
         nodes = self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             def append(self, node):
                 self.children.append(node)
 
index 78b56f1d463165e0a2deac7ffe1820881646aa62..0937c354f988a5a0e1d0b89722860d6f98ca6cd9 100644 (file)
@@ -31,6 +31,8 @@ from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import Conditional
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.provision import normalize_sequence
 from sqlalchemy.testing.schema import Column
@@ -209,10 +211,10 @@ class UnicodeSchemaTest(fixtures.MappedTest):
     def test_mapping(self):
         t2, t1 = self.tables.t2, self.tables.t1
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
-        class B(fixtures.ComparableEntity):
+        class B(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -251,7 +253,7 @@ class UnicodeSchemaTest(fixtures.MappedTest):
     def test_inheritance_mapping(self):
         t2, t1 = self.tables.t2, self.tables.t1
 
-        class A(fixtures.ComparableEntity):
+        class A(ComparableEntity):
             pass
 
         class B(A):
@@ -1030,7 +1032,7 @@ class ColumnCollisionTest(fixtures.MappedTest):
     def test_naming(self):
         book = self.tables.book
 
-        class Book(fixtures.ComparableEntity):
+        class Book(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Book, book)
@@ -1909,7 +1911,7 @@ class SaveTest(_fixtures.FixtureTest):
     def test_synonym(self):
         users = self.tables.users
 
-        class SUser(fixtures.BasicEntity):
+        class SUser(BasicEntity):
             def _get_name(self):
                 return "User:" + self.name
 
@@ -2773,7 +2775,7 @@ class ManyToManyTest(_fixtures.FixtureTest):
             self.classes.Item,
         )
 
-        class IKAssociation(fixtures.ComparableEntity):
+        class IKAssociation(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(Keyword, keywords)
@@ -3026,7 +3028,7 @@ class BooleanColTest(fixtures.MappedTest):
         t1_t = self.tables.t1_t
 
         # use the regular mapper
-        class T(fixtures.ComparableEntity):
+        class T(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(T, t1_t)
index 5cf8bd573f33c18a67ed613381b5b239b650ea48..5ca34d91747d4d1d9f06ac6966a82087706c6a91 100644 (file)
@@ -44,6 +44,8 @@ from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import Conditional
 from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.provision import normalize_sequence
 from sqlalchemy.testing.schema import Column
@@ -1372,10 +1374,10 @@ class SingleCyclePlusAttributeTest(
     def test_flush_size(self):
         foobars, nodes = self.tables.foobars, self.tables.nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             pass
 
-        class FooBar(fixtures.ComparableEntity):
+        class FooBar(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1440,7 +1442,7 @@ class SingleCycleM2MTest(
     def test_many_to_many_one(self):
         nodes, node_to_nodes = self.tables.nodes, self.tables.node_to_nodes
 
-        class Node(fixtures.ComparableEntity):
+        class Node(ComparableEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1584,10 +1586,10 @@ class RowswitchAccountingTest(fixtures.MappedTest):
     def _fixture(self):
         parent, child = self.tables.parent, self.tables.child
 
-        class Parent(fixtures.BasicEntity):
+        class Parent(BasicEntity):
             pass
 
-        class Child(fixtures.BasicEntity):
+        class Child(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1678,13 +1680,13 @@ class RowswitchM2OTest(fixtures.MappedTest):
     def _fixture(self):
         a, b, c = self.tables.a, self.tables.b, self.tables.c
 
-        class A(fixtures.BasicEntity):
+        class A(BasicEntity):
             pass
 
-        class B(fixtures.BasicEntity):
+        class B(BasicEntity):
             pass
 
-        class C(fixtures.BasicEntity):
+        class C(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -1787,10 +1789,10 @@ class BasicStaleChecksTest(fixtures.MappedTest):
     def _fixture(self, confirm_deleted_rows=True):
         parent, child = self.tables.parent, self.tables.child
 
-        class Parent(fixtures.BasicEntity):
+        class Parent(BasicEntity):
             pass
 
-        class Child(fixtures.BasicEntity):
+        class Child(BasicEntity):
             pass
 
         self.mapper_registry.map_imperatively(
@@ -2081,7 +2083,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
 
         t = self.tables.t
 
-        class T(fixtures.ComparableEntity):
+        class T(ComparableEntity):
             pass
 
         mp = self.mapper_registry.map_imperatively(T, t)
index 990d6a4c4b1f411a2ee57c1fbf7bbacd2511abea..df7334d5cb584d1a81b23badd03e14b6a2b6de6f 100644 (file)
@@ -9,8 +9,8 @@ from sqlalchemy.orm import validates
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
-from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from test.orm import _fixtures
 
@@ -20,7 +20,7 @@ class ValidatorTest(_fixtures.FixtureTest):
         users = self.tables.users
         canary = Mock()
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @validates("name")
             def validate_name(self, key, name):
                 canary(key, name)
@@ -52,7 +52,7 @@ class ValidatorTest(_fixtures.FixtureTest):
 
         canary = Mock()
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @validates("addresses")
             def validate_address(self, key, ad):
                 canary(key, ad)
@@ -87,7 +87,7 @@ class ValidatorTest(_fixtures.FixtureTest):
             self.classes.Address,
         )
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @validates("name")
             def validate_name(self, key, name):
                 ne_(name, "fred")
@@ -119,7 +119,7 @@ class ValidatorTest(_fixtures.FixtureTest):
         )
         canary = Mock()
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @validates("name", include_removes=True)
             def validate_name(self, key, item, remove):
                 canary(key, item, remove)
@@ -175,7 +175,7 @@ class ValidatorTest(_fixtures.FixtureTest):
             self.classes.Address,
         )
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @validates("addresses", include_removes=True)
             def validate_address(self, key, item, remove):
                 if not remove:
@@ -210,7 +210,7 @@ class ValidatorTest(_fixtures.FixtureTest):
             self.classes.Address,
         )
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             @validates("addresses", include_removes=True)
             def validate_address(self, key, item, remove):
                 if not remove:
@@ -264,7 +264,7 @@ class ValidatorTest(_fixtures.FixtureTest):
                 ne_(name, "fred")
                 return name + " modified"
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             sv = validates("name")(SomeValidator())
 
         self.mapper_registry.map_imperatively(User, users)
@@ -332,7 +332,7 @@ class ValidatorTest(_fixtures.FixtureTest):
             bool(include_removes) and not include_removes.default
         )
 
-        class User(fixtures.ComparableEntity):
+        class User(ComparableEntity):
             if need_remove_param:
 
                 @validates("addresses", **validate_kw)
@@ -347,7 +347,7 @@ class ValidatorTest(_fixtures.FixtureTest):
                     canary(key, item)
                     return item
 
-        class Address(fixtures.ComparableEntity):
+        class Address(ComparableEntity):
             if need_remove_param:
 
                 @validates("user", **validate_kw)
diff --git a/test/typing/plain_files/engine/engines.py b/test/typing/plain_files/engine/engines.py
new file mode 100644 (file)
index 0000000..5777b91
--- /dev/null
@@ -0,0 +1,34 @@
+from sqlalchemy import create_engine
+from sqlalchemy import Pool
+from sqlalchemy import text
+
+
+def regular() -> None:
+    e = create_engine("sqlite://")
+
+    # EXPECTED_TYPE: Engine
+    reveal_type(e)
+
+    with e.connect() as conn:
+        # EXPECTED_TYPE: Connection
+        reveal_type(conn)
+
+        result = conn.execute(text("select * from table"))
+
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
+
+    with e.begin() as conn:
+        # EXPECTED_TYPE: Connection
+        reveal_type(conn)
+
+        result = conn.execute(text("select * from table"))
+
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
+
+    engine = create_engine("postgresql://scott:tiger@localhost/test")
+    status: str = engine.pool.status()
+    other_pool: Pool = engine.pool.recreate()
+
+    print(status, other_pool)
similarity index 94%
rename from test/ext/mypy/plain_files/async_sessionmaker.py
rename to test/typing/plain_files/ext/asyncio/async_sessionmaker.py
index c253774e2e9c6760245594b0dedda411d1269b46..664ff0411dff74fe68c6af01e73181f3d56cdef3 100644 (file)
@@ -88,5 +88,9 @@ async def async_main() -> None:
 
         await session.commit()
 
+        trans_ctx = engine.begin()
+        async with trans_ctx as connection:
+            await connection.execute(select(A))
+
 
 asyncio.run(async_main())
diff --git a/test/typing/plain_files/ext/asyncio/async_stuff.py b/test/typing/plain_files/ext/asyncio/async_stuff.py
new file mode 100644 (file)
index 0000000..9afd0b8
--- /dev/null
@@ -0,0 +1,39 @@
+from asyncio import current_task
+
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+
+
+engine = create_async_engine("")
+SM = async_sessionmaker(engine, class_=AsyncSession)
+
+async_session = AsyncSession(engine)
+
+as_session = async_scoped_session(SM, current_task)
+
+
+async def go() -> None:
+    r = await async_session.scalars(text("select 1"), params=[])
+    r.first()
+    sr = await async_session.stream_scalars(text("select 1"), params=[])
+    await sr.all()
+    r = await as_session.scalars(text("select 1"), params=[])
+    r.first()
+    sr = await as_session.stream_scalars(text("select 1"), params=[])
+    await sr.all()
+
+    async with engine.connect() as conn:
+        cr = await conn.scalars(text("select 1"))
+        cr.first()
+        scr = await conn.stream_scalars(text("select 1"))
+        await scr.all()
+
+    ast = async_session.get_transaction()
+    if ast:
+        ast.is_active
+    nt = async_session.get_nested_transaction()
+    if nt:
+        nt.is_active
diff --git a/test/typing/plain_files/ext/asyncio/create_proxy_methods.py b/test/typing/plain_files/ext/asyncio/create_proxy_methods.py
new file mode 100644 (file)
index 0000000..235cf32
--- /dev/null
@@ -0,0 +1,97 @@
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import AsyncConnection
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio.session import async_sessionmaker
+
+# async engine
+async_engine: AsyncEngine = create_async_engine("")
+async_engine.clear_compiled_cache()
+async_engine.update_execution_options()
+async_engine.get_execution_options()
+async_engine.url
+async_engine.pool
+async_engine.dialect
+async_engine.engine
+async_engine.name
+async_engine.driver
+async_engine.echo
+
+
+# async connection
+async def go_async_conn() -> None:
+    async_conn: AsyncConnection = await async_engine.connect()
+    async_conn.closed
+    async_conn.invalidated
+    async_conn.dialect
+    async_conn.default_isolation_level
+
+
+# async session
+AsyncSession.object_session(object())
+AsyncSession.identity_key()
+async_session: AsyncSession = AsyncSession(async_engine)
+in_: bool = "foo" in async_session
+list(async_session)
+async_session.add(object())
+async_session.add_all([])
+async_session.expire(object())
+async_session.expire_all()
+async_session.expunge(object())
+async_session.expunge_all()
+async_session.get_bind()
+async_session.is_modified(object())
+async_session.in_transaction()
+async_session.in_nested_transaction()
+async_session.dirty
+async_session.deleted
+async_session.new
+async_session.identity_map
+async_session.is_active
+async_session.autoflush
+async_session.no_autoflush
+async_session.info
+
+
+# async scoped session
+async def test_async_scoped_session() -> None:
+    async_scoped_session.object_session(object())
+    async_scoped_session.identity_key()
+    await async_scoped_session.close_all()
+    asm = async_sessionmaker()
+    async_ss = async_scoped_session(asm, lambda: 42)
+    value: bool = "foo" in async_ss
+    print(value)
+    list(async_ss)
+    async_ss.add(object())
+    async_ss.add_all([])
+    async_ss.begin()
+    async_ss.begin_nested()
+    await async_ss.close()
+    await async_ss.commit()
+    await async_ss.connection()
+    await async_ss.delete(object())
+    await async_ss.execute(text("select 1"))
+    async_ss.expire(object())
+    async_ss.expire_all()
+    async_ss.expunge(object())
+    async_ss.expunge_all()
+    await async_ss.flush()
+    await async_ss.get(object, 1)
+    async_ss.get_bind()
+    async_ss.is_modified(object())
+    await async_ss.merge(object())
+    await async_ss.refresh(object())
+    await async_ss.rollback()
+    await async_ss.scalar(text("select 1"))
+    async_ss.bind
+    async_ss.dirty
+    async_ss.deleted
+    async_ss.new
+    async_ss.identity_map
+    async_ss.is_active
+    async_ss.autoflush
+    async_ss.no_autoflush
+    async_ss.info
similarity index 73%
rename from test/ext/mypy/plain_files/engines.py
rename to test/typing/plain_files/ext/asyncio/engines.py
index b7621aca42d59bf898bf28981c386cd4c5bc19c8..598d319a7765ef17db5d0360369df1b8717739cf 100644 (file)
@@ -1,33 +1,7 @@
-from sqlalchemy import create_engine
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
 
-def regular() -> None:
-    e = create_engine("sqlite://")
-
-    # EXPECTED_TYPE: Engine
-    reveal_type(e)
-
-    with e.connect() as conn:
-        # EXPECTED_TYPE: Connection
-        reveal_type(conn)
-
-        result = conn.execute(text("select * from table"))
-
-        # EXPECTED_TYPE: CursorResult[Any]
-        reveal_type(result)
-
-    with e.begin() as conn:
-        # EXPECTED_TYPE: Connection
-        reveal_type(conn)
-
-        result = conn.execute(text("select * from table"))
-
-        # EXPECTED_TYPE: CursorResult[Any]
-        reveal_type(result)
-
-
 async def asyncio() -> None:
     e = create_async_engine("sqlite://")
 
similarity index 78%
rename from test/ext/mypy/inspection_inspect.py
rename to test/typing/plain_files/inspection_inspect.py
index c67b515f40701e933dcb9119dc92133f550a727e..155ceffc0358d574a4ae43a884b5aa27935a15d5 100644 (file)
@@ -4,16 +4,21 @@ test inspect()
 however this is not really working
 
 """
+from typing import Any
+from typing import Optional
+
 from sqlalchemy import Column
 from sqlalchemy import create_engine
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapper
 
-Base = declarative_base()
+
+class Base(DeclarativeBase):
+    pass
 
 
 class A(Base):
@@ -30,9 +35,9 @@ e = create_engine("sqlite://")
 # TODO: I can't get these to work, pylance and mypy both don't want
 # to accommodate for different types for the first argument
 
-t: bool = inspect(a1).transient
+t: Optional[Any] = inspect(a1)
 
-m: Mapper = inspect(A)
+m: Mapper[Any] = inspect(A)
 
 inspect(e).get_table_names()
 
similarity index 99%
rename from test/ext/mypy/plugin_files/complete_orm_no_plugin.py
rename to test/typing/plain_files/orm/complete_orm_no_plugin.py
index 53291501ad3dfb0c8f292e7a66d3e3af1a01a277..b22057a2f36bdd4e175c501205a7ee35f82668bf 100644 (file)
@@ -1,4 +1,3 @@
-# NOPLUGINS
 # this should pass typing with no plugins
 
 from typing import Any
similarity index 98%
rename from test/ext/mypy/plain_files/declared_attr_one.py
rename to test/typing/plain_files/orm/declared_attr_one.py
index 86f8cf77041569728545eacc11ff1499a6a9587a..fc304db87e97ca0cf852a21cf23075eb865f9e7c 100644 (file)
@@ -30,7 +30,7 @@ class Employee(Base):
     }
 
     __table_args__ = (
-        Index("my_index", name, type),
+        Index("my_index", name, type.desc()),
         UniqueConstraint(name),
         PrimaryKeyConstraint(id),
         {"prefix": []},
diff --git a/test/typing/plain_files/orm/mapped_assign_expression.py b/test/typing/plain_files/orm/mapped_assign_expression.py
new file mode 100644 (file)
index 0000000..e68b4b4
--- /dev/null
@@ -0,0 +1,27 @@
+from datetime import datetime
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import Session
+from sqlalchemy.sql.functions import now
+from sqlalchemy.testing.schema import mapped_column
+
+mapper_registry: registry = registry()
+e = create_engine("sqlite:///database.db", echo=True)
+
+
+@mapper_registry.mapped
+class A:
+    __tablename__ = "a"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    date_time: Mapped[datetime]
+
+
+mapper_registry.metadata.create_all(e)
+
+with Session(e) as s:
+    a = A()
+    a.date_time = now()
+    s.add(a)
+    s.commit()
similarity index 95%
rename from test/ext/mypy/plain_files/experimental_relationship.py
rename to test/typing/plain_files/orm/relationship.py
index 7acec89e1875e7959745c427f297fd9da0e38e8c..ddd51e21e4376265c340944a8656845b63a9bb1a 100644 (file)
@@ -85,8 +85,8 @@ if typing.TYPE_CHECKING:
     # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
     reveal_type(Address.email_name)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.Address\]\]
     reveal_type(User.addresses_style_one)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\]
     reveal_type(User.addresses_style_two)
diff --git a/test/typing/plain_files/orm/scoped_session.py b/test/typing/plain_files/orm/scoped_session.py
new file mode 100644 (file)
index 0000000..9809901
--- /dev/null
@@ -0,0 +1,58 @@
+from sqlalchemy import inspect
+from sqlalchemy import text
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import scoped_session
+from sqlalchemy.orm import sessionmaker
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class X(Base):
+    __tablename__ = "x"
+    id: Mapped[int] = mapped_column(primary_key=True)
+
+
+scoped_session.object_session(object())
+scoped_session.identity_key()
+scoped_session.close_all()
+ss = scoped_session(sessionmaker())
+value: bool = "foo" in ss
+list(ss)
+ss.add(object())
+ss.add_all([])
+ss.begin()
+ss.begin_nested()
+ss.close()
+ss.commit()
+ss.connection()
+ss.delete(object())
+ss.execute(text("select 1"))
+ss.expire(object())
+ss.expire_all()
+ss.expunge(object())
+ss.expunge_all()
+ss.flush()
+ss.get(object, 1)
+b = ss.get_bind()
+ss.is_modified(object())
+ss.bulk_save_objects([])
+ss.bulk_insert_mappings(inspect(X), [])
+ss.bulk_update_mappings(inspect(X), [])
+ss.merge(object())
+q = (ss.query(object),)
+ss.refresh(object())
+ss.rollback()
+ss.scalar(text("select 1"))
+ss.bind
+ss.dirty
+ss.deleted
+ss.new
+ss.identity_map
+ss.is_active
+ss.autoflush
+ss.no_autoflush
+ss.info
diff --git a/test/typing/plain_files/sql/core_ddl.py b/test/typing/plain_files/sql/core_ddl.py
new file mode 100644 (file)
index 0000000..b7e0ec5
--- /dev/null
@@ -0,0 +1,151 @@
+from sqlalchemy import Boolean
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import FetchedValue
+from sqlalchemy import ForeignKey
+from sqlalchemy import func
+from sqlalchemy import Index
+from sqlalchemy import Integer
+from sqlalchemy import literal_column
+from sqlalchemy import MetaData
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import text
+from sqlalchemy import true
+from sqlalchemy import UUID
+
+
+m = MetaData()
+
+
+t1 = Table(
+    "t1",
+    m,
+    Column("id", Integer, primary_key=True),
+    Column("data", String),
+    Column("data2", String(50)),
+    Column("timestamp", DateTime()),
+    Index(None, "data2"),
+)
+
+t2 = Table(
+    "t2",
+    m,
+    Column("t1id", ForeignKey("t1.id")),
+    Column("q", Integer, CheckConstraint("q > 5")),
+)
+
+t3 = Table(
+    "t3",
+    m,
+    Column("x", Integer),
+    Column("y", Integer),
+    Column("t1id", ForeignKey(t1.c.id)),
+    PrimaryKeyConstraint("x", "y"),
+)
+
+t4 = Table(
+    "test_table",
+    m,
+    Column("i", UUID(as_uuid=True), nullable=False, primary_key=True),
+    Column("x", UUID(as_uuid=True), index=True),
+    Column("y", UUID(as_uuid=False), index=True),
+    Index("ix_xy_unique", "x", "y", unique=True),
+)
+
+
+# cols w/ no name or type, used by declarative
+c1: Column[int] = Column(ForeignKey(t3.c.x))
+# more colum args
+Column("name", Integer, index=True)
+Column(None, name="name")
+Column(Integer, name="name", index=True)
+Column("name", ForeignKey("a.id"))
+Column(ForeignKey("a.id"), type_=None, index=True)
+Column(ForeignKey("a.id"), name="name", type_=Integer())
+Column("name", None)
+Column("name", index=True)
+Column(ForeignKey("a.id"), name="name", index=True)
+Column(type_=None, index=True)
+Column(None, ForeignKey("a.id"))
+Column("name")
+Column(name="name", type_=None, index=True)
+Column(ForeignKey("a.id"), name="name", type_=None)
+Column(Integer)
+Column(ForeignKey("a.id"), type_=Integer())
+Column("name", Integer, ForeignKey("a.id"), index=True)
+Column("name", None, ForeignKey("a.id"), index=True)
+Column(ForeignKey("a.id"), index=True)
+Column("name", Integer)
+Column(Integer, name="name")
+Column(Integer, ForeignKey("a.id"), name="name", index=True)
+Column(ForeignKey("a.id"), type_=None)
+Column(ForeignKey("a.id"), name="name")
+Column(name="name", index=True)
+Column(type_=None)
+Column(None, index=True)
+Column(name="name", type_=None)
+Column(type_=Integer(), index=True)
+Column("name", Integer, ForeignKey("a.id"))
+Column(name="name", type_=Integer(), index=True)
+Column(Integer, ForeignKey("a.id"), index=True)
+Column("name", None, ForeignKey("a.id"))
+Column(index=True)
+Column("name", type_=None, index=True)
+Column("name", ForeignKey("a.id"), type_=Integer(), index=True)
+Column(ForeignKey("a.id"))
+Column(Integer, ForeignKey("a.id"))
+Column(Integer, ForeignKey("a.id"), name="name")
+Column("name", ForeignKey("a.id"), index=True)
+Column("name", type_=Integer(), index=True)
+Column(ForeignKey("a.id"), name="name", type_=Integer(), index=True)
+Column(name="name")
+Column("name", None, index=True)
+Column("name", ForeignKey("a.id"), type_=None, index=True)
+Column("name", type_=Integer())
+Column(None)
+Column(None, ForeignKey("a.id"), index=True)
+Column("name", ForeignKey("a.id"), type_=None)
+Column(type_=Integer())
+Column(None, ForeignKey("a.id"), name="name", index=True)
+Column(Integer, index=True)
+Column(ForeignKey("a.id"), name="name", type_=None, index=True)
+Column(ForeignKey("a.id"), type_=Integer(), index=True)
+Column(name="name", type_=Integer())
+Column(None, name="name", index=True)
+Column()
+Column(None, ForeignKey("a.id"), name="name")
+Column("name", type_=None)
+Column("name", ForeignKey("a.id"), type_=Integer())
+
+# server_default
+Column(Boolean, nullable=False, server_default=true())
+Column(DateTime, server_default=func.now(), nullable=False)
+Column(Boolean, server_default=func.xyzq(), nullable=False)
+# what would be *nice* to emit an error would be this, but this
+# is really not important, people don't usually put types in functions
+# as they are usually part of a bigger context where the type is known
+Column(Boolean, server_default=func.xyzq(type_=DateTime), nullable=False)
+Column(DateTime, server_default="now()")
+Column(DateTime, server_default=text("now()"))
+Column(DateTime, server_default=FetchedValue())
+Column(Boolean, server_default=literal_column("false", Boolean))
+Column("name", server_default=FetchedValue(), nullable=False)
+Column(server_default="now()", nullable=False)
+Column("name", Integer, server_default=text("now()"), nullable=False)
+Column(Integer, server_default=literal_column("42", Integer), nullable=False)
+
+# server_onupdate
+Column("name", server_onupdate=FetchedValue(), nullable=False)
+Column(server_onupdate=FetchedValue(), nullable=False)
+Column("name", Integer, server_onupdate=FetchedValue(), nullable=False)
+Column(Integer, server_onupdate=FetchedValue(), nullable=False)
+
+# TypeEngine.with_variant should accept both a TypeEngine instance and the Concrete Type
+Integer().with_variant(Integer, "mysql")
+Integer().with_variant(Integer(), "mysql")
+# Also test Variant.with_variant
+Integer().with_variant(Integer, "mysql").with_variant(Integer, "mysql")
+Integer().with_variant(Integer, "mysql").with_variant(Integer(), "mysql")
diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py
new file mode 100644 (file)
index 0000000..edfbd6b
--- /dev/null
@@ -0,0 +1,23 @@
+from sqlalchemy import func
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Foo(Base):
+    __tablename__ = "foo"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    a: Mapped[int]
+    b: Mapped[int]
+
+
+func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())
+func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
+func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
+func.row_number().over(order_by="a", partition_by=("a", "b"))
+func.row_number().over(partition_by="a", order_by=("a", "b"))
diff --git a/test/typing/plain_files/sql/lowercase_objects.py b/test/typing/plain_files/sql/lowercase_objects.py
new file mode 100644 (file)
index 0000000..ab26d7e
--- /dev/null
@@ -0,0 +1,16 @@
+import sqlalchemy as sa
+
+Book = sa.table(
+    "book",
+    sa.column("id", sa.Integer),
+    sa.column("name", sa.String),
+)
+Book.append_column(sa.column("other"))
+Book.corresponding_column(Book.c.id)
+
+value_expr = sa.values(
+    sa.column("id", sa.Integer), sa.column("name", sa.String), name="my_values"
+).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+sa.select(Book)
+sa.select(sa.literal_column("42"), sa.column("foo")).select_from(sa.table("t"))
diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py
new file mode 100644 (file)
index 0000000..41981d1
--- /dev/null
@@ -0,0 +1,137 @@
+from decimal import Decimal
+from typing import Any
+
+from sqlalchemy import ARRAY
+from sqlalchemy import BigInteger
+from sqlalchemy import column
+from sqlalchemy import ColumnElement
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class A(Base):
+    __tablename__ = "a"
+    id: Mapped[int]
+    string: Mapped[str]
+    arr: Mapped[list[int]] = mapped_column(ARRAY(Integer))
+
+
+lt1: "ColumnElement[bool]" = A.id > A.id
+lt2: "ColumnElement[bool]" = A.id > 1
+lt3: "ColumnElement[bool]" = 1 < A.id
+
+le1: "ColumnElement[bool]" = A.id >= A.id
+le2: "ColumnElement[bool]" = A.id >= 1
+le3: "ColumnElement[bool]" = 1 <= A.id
+
+eq1: "ColumnElement[bool]" = A.id == A.id
+eq2: "ColumnElement[bool]" = A.id == 1
+# eq3: "ColumnElement[bool]" = 1 == A.id
+
+ne1: "ColumnElement[bool]" = A.id != A.id
+ne2: "ColumnElement[bool]" = A.id != 1
+# ne3: "ColumnElement[bool]" = 1 != A.id
+
+gt1: "ColumnElement[bool]" = A.id < A.id
+gt2: "ColumnElement[bool]" = A.id < 1
+gt3: "ColumnElement[bool]" = 1 > A.id
+
+ge1: "ColumnElement[bool]" = A.id <= A.id
+ge2: "ColumnElement[bool]" = A.id <= 1
+ge3: "ColumnElement[bool]" = 1 >= A.id
+
+
+# TODO "in" doesn't seem to pick up the typing of __contains__?
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "bool", variable has type "ColumnElement[bool]") # noqa: E501
+contains1: "ColumnElement[bool]" = A.id in A.arr
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "bool", variable has type "ColumnElement[bool]") # noqa: E501
+contains2: "ColumnElement[bool]" = A.id in A.string
+
+lshift1: "ColumnElement[int]" = A.id << A.id
+lshift2: "ColumnElement[int]" = A.id << 1
+lshift3: "ColumnElement[Any]" = A.string << 1
+
+rshift1: "ColumnElement[int]" = A.id >> A.id
+rshift2: "ColumnElement[int]" = A.id >> 1
+rshift3: "ColumnElement[Any]" = A.string >> 1
+
+concat1: "ColumnElement[str]" = A.string.concat(A.string)
+concat2: "ColumnElement[str]" = A.string.concat(1)
+concat3: "ColumnElement[str]" = A.string.concat("a")
+
+like1: "ColumnElement[bool]" = A.string.like("test")
+like2: "ColumnElement[bool]" = A.string.like("test", escape="/")
+ilike1: "ColumnElement[bool]" = A.string.ilike("test")
+ilike2: "ColumnElement[bool]" = A.string.ilike("test", escape="/")
+
+in_: "ColumnElement[bool]" = A.id.in_([1, 2])
+not_in: "ColumnElement[bool]" = A.id.not_in([1, 2])
+
+not_like1: "ColumnElement[bool]" = A.string.not_like("test")
+not_like2: "ColumnElement[bool]" = A.string.not_like("test", escape="/")
+not_ilike1: "ColumnElement[bool]" = A.string.not_ilike("test")
+not_ilike2: "ColumnElement[bool]" = A.string.not_ilike("test", escape="/")
+
+is_: "ColumnElement[bool]" = A.string.is_("test")
+is_not: "ColumnElement[bool]" = A.string.is_not("test")
+
+startswith: "ColumnElement[bool]" = A.string.startswith("test")
+endswith: "ColumnElement[bool]" = A.string.endswith("test")
+contains: "ColumnElement[bool]" = A.string.contains("test")
+match: "ColumnElement[bool]" = A.string.match("test")
+regexp_match: "ColumnElement[bool]" = A.string.regexp_match("test")
+
+regexp_replace: "ColumnElement[str]" = A.string.regexp_replace(
+    "pattern", "replacement"
+)
+between: "ColumnElement[bool]" = A.string.between("a", "b")
+
+adds: "ColumnElement[str]" = A.string + A.string
+add1: "ColumnElement[int]" = A.id + A.id
+add2: "ColumnElement[int]" = A.id + 1
+add3: "ColumnElement[int]" = 1 + A.id
+
+sub1: "ColumnElement[int]" = A.id - A.id
+sub2: "ColumnElement[int]" = A.id - 1
+sub3: "ColumnElement[int]" = 1 - A.id
+
+mul1: "ColumnElement[int]" = A.id * A.id
+mul2: "ColumnElement[int]" = A.id * 1
+mul3: "ColumnElement[int]" = 1 * A.id
+
+div1: "ColumnElement[float|Decimal]" = A.id / A.id
+div2: "ColumnElement[float|Decimal]" = A.id / 1
+div3: "ColumnElement[float|Decimal]" = 1 / A.id
+
+mod1: "ColumnElement[int]" = A.id % A.id
+mod2: "ColumnElement[int]" = A.id % 1
+mod3: "ColumnElement[int]" = 1 % A.id
+
+# unary
+
+neg: "ColumnElement[int]" = -A.id
+
+desc: "ColumnElement[int]" = A.id.desc()
+asc: "ColumnElement[int]" = A.id.asc()
+any_: "ColumnElement[bool]" = A.id.any_()
+all_: "ColumnElement[bool]" = A.id.all_()
+nulls_first: "ColumnElement[int]" = A.id.nulls_first()
+nulls_last: "ColumnElement[int]" = A.id.nulls_last()
+collate: "ColumnElement[str]" = A.string.collate("somelang")
+distinct: "ColumnElement[int]" = A.id.distinct()
+
+
+# custom ops
+col = column("flags", Integer)
+op_a: "ColumnElement[Any]" = col.op("&")(1)
+op_b: "ColumnElement[int]" = col.op("&", return_type=Integer)(1)
+op_c: "ColumnElement[str]" = col.op("&", return_type=String)("1")
+op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1")
+op_e: "ColumnElement[bool]" = col.bool_op("&")("1")
diff --git a/test/typing/test_mypy.py b/test/typing/test_mypy.py
new file mode 100644 (file)
index 0000000..14d13bd
--- /dev/null
@@ -0,0 +1,17 @@
+import os
+
+from sqlalchemy import testing
+from sqlalchemy.testing import fixtures
+
+
+class MypyPlainTest(fixtures.MypyTest):
+    @testing.combinations(
+        *(
+            (os.path.basename(path), path)
+            for path in fixtures.MypyTest.file_combinations("plain_files")
+        ),
+        argnames="path",
+        id_="ia",
+    )
+    def test_mypy_no_plugin(self, mypy_typecheck_file, path):
+        mypy_typecheck_file(path)
index 5845e89ad93517ec31918be724120cb173d54fd4..848a927225056d0b728410a70a33a0d2e874f091 100644 (file)
@@ -136,7 +136,7 @@ def main(cmd: code_writer_cmd) -> None:
 
 
 functions_py = "lib/sqlalchemy/sql/functions.py"
-test_functions_py = "test/ext/mypy/plain_files/functions.py"
+test_functions_py = "test/typing/plain_files/sql/functions.py"
 
 
 if __name__ == "__main__":