From: Federico Caselli Date: Fri, 23 Jun 2023 17:58:54 +0000 (+0200) Subject: Improve typing tests X-Git-Tag: rel_2_0_18~21^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=c3acf8a5d23881ed4795fb5ca1c28fae0adc6414;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve typing tests Extract a fixture to run mypy on files Move the plain files to test/typing Move test files from stubs repository Transform the fixture module in a package Change-Id: I23acaecb84e7c4b9010259d44395dc1df83a9385 --- diff --git a/lib/sqlalchemy/orm/writeonly.py b/lib/sqlalchemy/orm/writeonly.py index 9f0dbeead2..0f245835b0 100644 --- a/lib/sqlalchemy/orm/writeonly.py +++ b/lib/sqlalchemy/orm/writeonly.py @@ -255,7 +255,7 @@ class WriteOnlyAttributeImpl( state._modified_event(dict_, self, attributes.NEVER_SET) - # this is a hack to allow the fixtures.ComparableEntity fixture + # this is a hack to allow the entities.ComparableEntity fixture # to work dict_[self.key] = True return state.committed_state[self.key] diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index d2bda4d83c..b8f03362f9 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -9,6 +9,7 @@ from __future__ import annotations +from argparse import Namespace import collections import inspect import typing @@ -34,6 +35,7 @@ test_schema_2 = None any_async = False _current = None ident = "main" +options: Namespace = None # type: ignore if typing.TYPE_CHECKING: from .plugin.plugin_base import FixtureFunctions diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py deleted file mode 100644 index cb08380f2e..0000000000 --- a/lib/sqlalchemy/testing/fixtures.py +++ /dev/null @@ -1,1055 +0,0 @@ -# testing/fixtures.py -# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - - -from __future__ import annotations - -import itertools -import random -import re -import sys -from typing import Any - -import sqlalchemy as sa -from . import assertions -from . import config -from . import mock -from . import schema -from .assertions import eq_ -from .assertions import ne_ -from .entities import BasicEntity -from .entities import ComparableEntity -from .entities import ComparableMixin # noqa -from .util import adict -from .util import drop_all_tables_from_metadata -from .. import Column -from .. import event -from .. import func -from .. import Integer -from .. import select -from .. import Table -from .. import util -from ..orm import DeclarativeBase -from ..orm import events as orm_events -from ..orm import MappedAsDataclass -from ..orm import registry -from ..schema import sort_tables_and_constraints -from ..sql import visitors -from ..sql.elements import ClauseElement - - -@config.mark_base_test_class() -class TestBase: - # A sequence of requirement names matching testing.requires decorators - __requires__ = () - - # A sequence of dialect names to exclude from the test class. - __unsupported_on__ = () - - # If present, test class is only runnable for the *single* specified - # dialect. If you need multiple, use __unsupported_on__ and invert. - __only_on__ = None - - # A sequence of no-arg callables. If any are True, the entire testcase is - # skipped. - __skip_if__ = None - - # if True, the testing reaper will not attempt to touch connection - # state after a test is completed and before the outer teardown - # starts - __leave_connections_for_teardown__ = False - - def assert_(self, val, msg=None): - assert val, msg - - @config.fixture() - def nocache(self): - _cache = config.db._compiled_cache - config.db._compiled_cache = None - yield - config.db._compiled_cache = _cache - - @config.fixture() - def connection_no_trans(self): - eng = getattr(self, "bind", None) or config.db - - with eng.connect() as conn: - yield conn - - @config.fixture() - def connection(self): - global _connection_fixture_connection - - eng = getattr(self, "bind", None) or config.db - - conn = eng.connect() - trans = conn.begin() - - _connection_fixture_connection = conn - yield conn - - _connection_fixture_connection = None - - if trans.is_active: - trans.rollback() - # trans would not be active here if the test is using - # the legacy @provide_metadata decorator still, as it will - # run a close all connections. - conn.close() - - @config.fixture() - def close_result_when_finished(self): - to_close = [] - to_consume = [] - - def go(result, consume=False): - to_close.append(result) - if consume: - to_consume.append(result) - - yield go - for r in to_consume: - try: - r.all() - except: - pass - for r in to_close: - try: - r.close() - except: - pass - - @config.fixture() - def registry(self, metadata): - reg = registry( - metadata=metadata, - type_annotation_map={ - str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb", "oracle" - ) - }, - ) - yield reg - reg.dispose() - - @config.fixture - def decl_base(self, metadata): - _md = metadata - - class Base(DeclarativeBase): - metadata = _md - type_annotation_map = { - str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb", "oracle" - ) - } - - yield Base - Base.registry.dispose() - - @config.fixture - def dc_decl_base(self, metadata): - _md = metadata - - class Base(MappedAsDataclass, DeclarativeBase): - metadata = _md - type_annotation_map = { - str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb" - ) - } - - yield Base - Base.registry.dispose() - - @config.fixture() - def future_connection(self, future_engine, connection): - # integrate the future_engine and connection fixtures so - # that users of the "connection" fixture will get at the - # "future" connection - yield connection - - @config.fixture() - def future_engine(self): - yield - - @config.fixture() - def testing_engine(self): - from . import engines - - def gen_testing_engine( - url=None, - options=None, - future=None, - asyncio=False, - transfer_staticpool=False, - share_pool=False, - ): - if options is None: - options = {} - options["scope"] = "fixture" - return engines.testing_engine( - url=url, - options=options, - asyncio=asyncio, - transfer_staticpool=transfer_staticpool, - share_pool=share_pool, - ) - - yield gen_testing_engine - - engines.testing_reaper._drop_testing_engines("fixture") - - @config.fixture() - def async_testing_engine(self, testing_engine): - def go(**kw): - kw["asyncio"] = True - return testing_engine(**kw) - - return go - - @config.fixture - def fixture_session(self): - return fixture_session() - - @config.fixture() - def metadata(self, request): - """Provide bound MetaData for a single test, dropping afterwards.""" - - from ..sql import schema - - metadata = schema.MetaData() - request.instance.metadata = metadata - yield metadata - del request.instance.metadata - - if ( - _connection_fixture_connection - and _connection_fixture_connection.in_transaction() - ): - trans = _connection_fixture_connection.get_transaction() - trans.rollback() - with _connection_fixture_connection.begin(): - drop_all_tables_from_metadata( - metadata, _connection_fixture_connection - ) - else: - drop_all_tables_from_metadata(metadata, config.db) - - @config.fixture( - params=[ - (rollback, second_operation, begin_nested) - for rollback in (True, False) - for second_operation in ("none", "execute", "begin") - for begin_nested in ( - True, - False, - ) - ] - ) - def trans_ctx_manager_fixture(self, request, metadata): - rollback, second_operation, begin_nested = request.param - - t = Table("test", metadata, Column("data", Integer)) - eng = getattr(self, "bind", None) or config.db - - t.create(eng) - - def run_test(subject, trans_on_subject, execute_on_subject): - with subject.begin() as trans: - if begin_nested: - if not config.requirements.savepoints.enabled: - config.skip_test("savepoints not enabled") - if execute_on_subject: - nested_trans = subject.begin_nested() - else: - nested_trans = trans.begin_nested() - - with nested_trans: - if execute_on_subject: - subject.execute(t.insert(), {"data": 10}) - else: - trans.execute(t.insert(), {"data": 10}) - - # for nested trans, we always commit/rollback on the - # "nested trans" object itself. - # only Session(future=False) will affect savepoint - # transaction for session.commit/rollback - - if rollback: - nested_trans.rollback() - else: - nested_trans.commit() - - if second_operation != "none": - with assertions.expect_raises_message( - sa.exc.InvalidRequestError, - "Can't operate on closed transaction " - "inside context " - "manager. Please complete the context " - "manager " - "before emitting further commands.", - ): - if second_operation == "execute": - if execute_on_subject: - subject.execute( - t.insert(), {"data": 12} - ) - else: - trans.execute(t.insert(), {"data": 12}) - elif second_operation == "begin": - if execute_on_subject: - subject.begin_nested() - else: - trans.begin_nested() - - # outside the nested trans block, but still inside the - # transaction block, we can run SQL, and it will be - # committed - if execute_on_subject: - subject.execute(t.insert(), {"data": 14}) - else: - trans.execute(t.insert(), {"data": 14}) - - else: - if execute_on_subject: - subject.execute(t.insert(), {"data": 10}) - else: - trans.execute(t.insert(), {"data": 10}) - - if trans_on_subject: - if rollback: - subject.rollback() - else: - subject.commit() - else: - if rollback: - trans.rollback() - else: - trans.commit() - - if second_operation != "none": - with assertions.expect_raises_message( - sa.exc.InvalidRequestError, - "Can't operate on closed transaction inside " - "context " - "manager. Please complete the context manager " - "before emitting further commands.", - ): - if second_operation == "execute": - if execute_on_subject: - subject.execute(t.insert(), {"data": 12}) - else: - trans.execute(t.insert(), {"data": 12}) - elif second_operation == "begin": - if hasattr(trans, "begin"): - trans.begin() - else: - subject.begin() - elif second_operation == "begin_nested": - if execute_on_subject: - subject.begin_nested() - else: - trans.begin_nested() - - expected_committed = 0 - if begin_nested: - # begin_nested variant, we inserted a row after the nested - # block - expected_committed += 1 - if not rollback: - # not rollback variant, our row inserted in the target - # block itself would be committed - expected_committed += 1 - - if execute_on_subject: - eq_( - subject.scalar(select(func.count()).select_from(t)), - expected_committed, - ) - else: - with subject.connect() as conn: - eq_( - conn.scalar(select(func.count()).select_from(t)), - expected_committed, - ) - - return run_test - - -_connection_fixture_connection = None - - -class FutureEngineMixin: - """alembic's suite still using this""" - - -class TablesTest(TestBase): - # 'once', None - run_setup_bind = "once" - - # 'once', 'each', None - run_define_tables = "once" - - # 'once', 'each', None - run_create_tables = "once" - - # 'once', 'each', None - run_inserts = "each" - - # 'each', None - run_deletes = "each" - - # 'once', None - run_dispose_bind = None - - bind = None - _tables_metadata = None - tables = None - other = None - sequences = None - - @config.fixture(autouse=True, scope="class") - def _setup_tables_test_class(self): - cls = self.__class__ - cls._init_class() - - cls._setup_once_tables() - - cls._setup_once_inserts() - - yield - - cls._teardown_once_metadata_bind() - - @config.fixture(autouse=True, scope="function") - def _setup_tables_test_instance(self): - self._setup_each_tables() - self._setup_each_inserts() - - yield - - self._teardown_each_tables() - - @property - def tables_test_metadata(self): - return self._tables_metadata - - @classmethod - def _init_class(cls): - if cls.run_define_tables == "each": - if cls.run_create_tables == "once": - cls.run_create_tables = "each" - assert cls.run_inserts in ("each", None) - - cls.other = adict() - cls.tables = adict() - cls.sequences = adict() - - cls.bind = cls.setup_bind() - cls._tables_metadata = sa.MetaData() - - @classmethod - def _setup_once_inserts(cls): - if cls.run_inserts == "once": - cls._load_fixtures() - with cls.bind.begin() as conn: - cls.insert_data(conn) - - @classmethod - def _setup_once_tables(cls): - if cls.run_define_tables == "once": - cls.define_tables(cls._tables_metadata) - if cls.run_create_tables == "once": - cls._tables_metadata.create_all(cls.bind) - cls.tables.update(cls._tables_metadata.tables) - cls.sequences.update(cls._tables_metadata._sequences) - - def _setup_each_tables(self): - if self.run_define_tables == "each": - self.define_tables(self._tables_metadata) - if self.run_create_tables == "each": - self._tables_metadata.create_all(self.bind) - self.tables.update(self._tables_metadata.tables) - self.sequences.update(self._tables_metadata._sequences) - elif self.run_create_tables == "each": - self._tables_metadata.create_all(self.bind) - - def _setup_each_inserts(self): - if self.run_inserts == "each": - self._load_fixtures() - with self.bind.begin() as conn: - self.insert_data(conn) - - def _teardown_each_tables(self): - if self.run_define_tables == "each": - self.tables.clear() - if self.run_create_tables == "each": - drop_all_tables_from_metadata(self._tables_metadata, self.bind) - self._tables_metadata.clear() - elif self.run_create_tables == "each": - drop_all_tables_from_metadata(self._tables_metadata, self.bind) - - savepoints = getattr(config.requirements, "savepoints", False) - if savepoints: - savepoints = savepoints.enabled - - # no need to run deletes if tables are recreated on setup - if ( - self.run_define_tables != "each" - and self.run_create_tables != "each" - and self.run_deletes == "each" - ): - with self.bind.begin() as conn: - for table in reversed( - [ - t - for (t, fks) in sort_tables_and_constraints( - self._tables_metadata.tables.values() - ) - if t is not None - ] - ): - try: - if savepoints: - with conn.begin_nested(): - conn.execute(table.delete()) - else: - conn.execute(table.delete()) - except sa.exc.DBAPIError as ex: - print( - ("Error emptying table %s: %r" % (table, ex)), - file=sys.stderr, - ) - - @classmethod - def _teardown_once_metadata_bind(cls): - if cls.run_create_tables: - drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) - - if cls.run_dispose_bind == "once": - cls.dispose_bind(cls.bind) - - cls._tables_metadata.bind = None - - if cls.run_setup_bind is not None: - cls.bind = None - - @classmethod - def setup_bind(cls): - return config.db - - @classmethod - def dispose_bind(cls, bind): - if hasattr(bind, "dispose"): - bind.dispose() - elif hasattr(bind, "close"): - bind.close() - - @classmethod - def define_tables(cls, metadata): - pass - - @classmethod - def fixtures(cls): - return {} - - @classmethod - def insert_data(cls, connection): - pass - - def sql_count_(self, count, fn): - self.assert_sql_count(self.bind, fn, count) - - def sql_eq_(self, callable_, statements): - self.assert_sql(self.bind, callable_, statements) - - @classmethod - def _load_fixtures(cls): - """Insert rows as represented by the fixtures() method.""" - headers, rows = {}, {} - for table, data in cls.fixtures().items(): - if len(data) < 2: - continue - if isinstance(table, str): - table = cls.tables[table] - headers[table] = data[0] - rows[table] = data[1:] - for table, fks in sort_tables_and_constraints( - cls._tables_metadata.tables.values() - ): - if table is None: - continue - if table not in headers: - continue - with cls.bind.begin() as conn: - conn.execute( - table.insert(), - [ - dict(zip(headers[table], column_values)) - for column_values in rows[table] - ], - ) - - -class NoCache: - @config.fixture(autouse=True, scope="function") - def _disable_cache(self): - _cache = config.db._compiled_cache - config.db._compiled_cache = None - yield - config.db._compiled_cache = _cache - - -class RemovesEvents: - @util.memoized_property - def _event_fns(self): - return set() - - def event_listen(self, target, name, fn, **kw): - self._event_fns.add((target, name, fn)) - event.listen(target, name, fn, **kw) - - @config.fixture(autouse=True, scope="function") - def _remove_events(self): - yield - for key in self._event_fns: - event.remove(*key) - - -class RemoveORMEventsGlobally: - @config.fixture(autouse=True) - def _remove_listeners(self): - yield - orm_events.MapperEvents._clear() - orm_events.InstanceEvents._clear() - orm_events.SessionEvents._clear() - orm_events.InstrumentationEvents._clear() - orm_events.QueryEvents._clear() - - -_fixture_sessions = set() - - -def fixture_session(**kw): - kw.setdefault("autoflush", True) - kw.setdefault("expire_on_commit", True) - - bind = kw.pop("bind", config.db) - - sess = sa.orm.Session(bind, **kw) - _fixture_sessions.add(sess) - return sess - - -def _close_all_sessions(): - # will close all still-referenced sessions - sa.orm.session.close_all_sessions() - _fixture_sessions.clear() - - -def stop_test_class_inside_fixtures(cls): - _close_all_sessions() - sa.orm.clear_mappers() - - -def after_test(): - if _fixture_sessions: - _close_all_sessions() - - -class ORMTest(TestBase): - pass - - -class MappedTest(TablesTest, assertions.AssertsExecutionResults): - # 'once', 'each', None - run_setup_classes = "once" - - # 'once', 'each', None - run_setup_mappers = "each" - - classes: Any = None - - @config.fixture(autouse=True, scope="class") - def _setup_tables_test_class(self): - cls = self.__class__ - cls._init_class() - - if cls.classes is None: - cls.classes = adict() - - cls._setup_once_tables() - cls._setup_once_classes() - cls._setup_once_mappers() - cls._setup_once_inserts() - - yield - - cls._teardown_once_class() - cls._teardown_once_metadata_bind() - - @config.fixture(autouse=True, scope="function") - def _setup_tables_test_instance(self): - self._setup_each_tables() - self._setup_each_classes() - self._setup_each_mappers() - self._setup_each_inserts() - - yield - - sa.orm.session.close_all_sessions() - self._teardown_each_mappers() - self._teardown_each_classes() - self._teardown_each_tables() - - @classmethod - def _teardown_once_class(cls): - cls.classes.clear() - - @classmethod - def _setup_once_classes(cls): - if cls.run_setup_classes == "once": - cls._with_register_classes(cls.setup_classes) - - @classmethod - def _setup_once_mappers(cls): - if cls.run_setup_mappers == "once": - cls.mapper_registry, cls.mapper = cls._generate_registry() - cls._with_register_classes(cls.setup_mappers) - - def _setup_each_mappers(self): - if self.run_setup_mappers != "once": - ( - self.__class__.mapper_registry, - self.__class__.mapper, - ) = self._generate_registry() - - if self.run_setup_mappers == "each": - self._with_register_classes(self.setup_mappers) - - def _setup_each_classes(self): - if self.run_setup_classes == "each": - self._with_register_classes(self.setup_classes) - - @classmethod - def _generate_registry(cls): - decl = registry(metadata=cls._tables_metadata) - return decl, decl.map_imperatively - - @classmethod - def _with_register_classes(cls, fn): - """Run a setup method, framing the operation with a Base class - that will catch new subclasses to be established within - the "classes" registry. - - """ - cls_registry = cls.classes - - class _Base: - def __init_subclass__(cls) -> None: - assert cls_registry is not None - cls_registry[cls.__name__] = cls - super().__init_subclass__() - - class Basic(BasicEntity, _Base): - pass - - class Comparable(ComparableEntity, _Base): - pass - - cls.Basic = Basic - cls.Comparable = Comparable - fn() - - def _teardown_each_mappers(self): - # some tests create mappers in the test bodies - # and will define setup_mappers as None - - # clear mappers in any case - if self.run_setup_mappers != "once": - sa.orm.clear_mappers() - - def _teardown_each_classes(self): - if self.run_setup_classes != "once": - self.classes.clear() - - @classmethod - def setup_classes(cls): - pass - - @classmethod - def setup_mappers(cls): - pass - - -class DeclarativeMappedTest(MappedTest): - run_setup_classes = "once" - run_setup_mappers = "once" - - @classmethod - def _setup_once_tables(cls): - pass - - @classmethod - def _with_register_classes(cls, fn): - cls_registry = cls.classes - - class _DeclBase(DeclarativeBase): - __table_cls__ = schema.Table - metadata = cls._tables_metadata - type_annotation_map = { - str: sa.String().with_variant( - sa.String(50), "mysql", "mariadb", "oracle" - ) - } - - def __init_subclass__(cls, **kw) -> None: - assert cls_registry is not None - cls_registry[cls.__name__] = cls - super().__init_subclass__(**kw) - - cls.DeclarativeBasic = _DeclBase - - # sets up cls.Basic which is helpful for things like composite - # classes - super()._with_register_classes(fn) - - if cls._tables_metadata.tables and cls.run_create_tables: - cls._tables_metadata.create_all(config.db) - - -class ComputedReflectionFixtureTest(TablesTest): - run_inserts = run_deletes = None - - __backend__ = True - __requires__ = ("computed_columns", "table_reflection") - - regexp = re.compile(r"[\[\]\(\)\s`'\"]*") - - def normalize(self, text): - return self.regexp.sub("", text).lower() - - @classmethod - def define_tables(cls, metadata): - from .. import Integer - from .. import testing - from ..schema import Column - from ..schema import Computed - from ..schema import Table - - Table( - "computed_default_table", - metadata, - Column("id", Integer, primary_key=True), - Column("normal", Integer), - Column("computed_col", Integer, Computed("normal + 42")), - Column("with_default", Integer, server_default="42"), - ) - - t = Table( - "computed_column_table", - metadata, - Column("id", Integer, primary_key=True), - Column("normal", Integer), - Column("computed_no_flag", Integer, Computed("normal + 42")), - ) - - if testing.requires.schemas.enabled: - t2 = Table( - "computed_column_table", - metadata, - Column("id", Integer, primary_key=True), - Column("normal", Integer), - Column("computed_no_flag", Integer, Computed("normal / 42")), - schema=config.test_schema, - ) - - if testing.requires.computed_columns_virtual.enabled: - t.append_column( - Column( - "computed_virtual", - Integer, - Computed("normal + 2", persisted=False), - ) - ) - if testing.requires.schemas.enabled: - t2.append_column( - Column( - "computed_virtual", - Integer, - Computed("normal / 2", persisted=False), - ) - ) - if testing.requires.computed_columns_stored.enabled: - t.append_column( - Column( - "computed_stored", - Integer, - Computed("normal - 42", persisted=True), - ) - ) - if testing.requires.schemas.enabled: - t2.append_column( - Column( - "computed_stored", - Integer, - Computed("normal * 42", persisted=True), - ) - ) - - -class CacheKeyFixture: - def _compare_equal(self, a, b, compare_values): - a_key = a._generate_cache_key() - b_key = b._generate_cache_key() - - if a_key is None: - assert a._annotations.get("nocache") - - assert b_key is None - else: - eq_(a_key.key, b_key.key) - eq_(hash(a_key.key), hash(b_key.key)) - - for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): - assert a_param.compare(b_param, compare_values=compare_values) - return a_key, b_key - - def _run_cache_key_fixture(self, fixture, compare_values): - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - if a == b: - a_key, b_key = self._compare_equal( - case_a[a], case_b[b], compare_values - ) - if a_key is None: - continue - else: - a_key = case_a[a]._generate_cache_key() - b_key = case_b[b]._generate_cache_key() - - if a_key is None or b_key is None: - if a_key is None: - assert case_a[a]._annotations.get("nocache") - if b_key is None: - assert case_b[b]._annotations.get("nocache") - continue - - if a_key.key == b_key.key: - for a_param, b_param in zip( - a_key.bindparams, b_key.bindparams - ): - if not a_param.compare( - b_param, compare_values=compare_values - ): - break - else: - # this fails unconditionally since we could not - # find bound parameter values that differed. - # Usually we intended to get two distinct keys here - # so the failure will be more descriptive using the - # ne_() assertion. - ne_(a_key.key, b_key.key) - else: - ne_(a_key.key, b_key.key) - - # ClauseElement-specific test to ensure the cache key - # collected all the bound parameters that aren't marked - # as "literal execute" - if isinstance(case_a[a], ClauseElement) and isinstance( - case_b[b], ClauseElement - ): - assert_a_params = [] - assert_b_params = [] - - for elem in visitors.iterate(case_a[a]): - if elem.__visit_name__ == "bindparam": - assert_a_params.append(elem) - - for elem in visitors.iterate(case_b[b]): - if elem.__visit_name__ == "bindparam": - assert_b_params.append(elem) - - # note we're asserting the order of the params as well as - # if there are dupes or not. ordering has to be - # deterministic and matches what a traversal would provide. - eq_( - sorted(a_key.bindparams, key=lambda b: b.key), - sorted( - util.unique_list(assert_a_params), key=lambda b: b.key - ), - ) - eq_( - sorted(b_key.bindparams, key=lambda b: b.key), - sorted( - util.unique_list(assert_b_params), key=lambda b: b.key - ), - ) - - def _run_cache_key_equal_fixture(self, fixture, compare_values): - case_a = fixture() - case_b = fixture() - - for a, b in itertools.combinations_with_replacement( - range(len(case_a)), 2 - ): - self._compare_equal(case_a[a], case_b[b], compare_values) - - -def insertmanyvalues_fixture( - connection, randomize_rows=False, warn_on_downgraded=False -): - dialect = connection.dialect - orig_dialect = dialect._deliver_insertmanyvalues_batches - orig_conn = connection._exec_insertmany_context - - class RandomCursor: - __slots__ = ("cursor",) - - def __init__(self, cursor): - self.cursor = cursor - - # only this method is called by the deliver method. - # by not having the other methods we assert that those aren't being - # used - - def fetchall(self): - rows = self.cursor.fetchall() - rows = list(rows) - random.shuffle(rows) - return rows - - def _deliver_insertmanyvalues_batches( - cursor, statement, parameters, generic_setinputsizes, context - ): - if randomize_rows: - cursor = RandomCursor(cursor) - for batch in orig_dialect( - cursor, statement, parameters, generic_setinputsizes, context - ): - if warn_on_downgraded and batch.is_downgraded: - util.warn("Batches were downgraded for sorted INSERT") - - yield batch - - def _exec_insertmany_context( - dialect, - context, - ): - with mock.patch.object( - dialect, - "_deliver_insertmanyvalues_batches", - new=_deliver_insertmanyvalues_batches, - ): - return orig_conn(dialect, context) - - connection._exec_insertmany_context = _exec_insertmany_context diff --git a/lib/sqlalchemy/testing/fixtures/__init__.py b/lib/sqlalchemy/testing/fixtures/__init__.py new file mode 100644 index 0000000000..932051ce8e --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures/__init__.py @@ -0,0 +1,28 @@ +# testing/fixtures/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from .base import FutureEngineMixin as FutureEngineMixin +from .base import TestBase as TestBase +from .mypy import MypyTest as MypyTest +from .orm import after_test as after_test +from .orm import close_all_sessions as close_all_sessions +from .orm import DeclarativeMappedTest as DeclarativeMappedTest +from .orm import fixture_session as fixture_session +from .orm import MappedTest as MappedTest +from .orm import ORMTest as ORMTest +from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally +from .orm import ( + stop_test_class_inside_fixtures as stop_test_class_inside_fixtures, +) +from .sql import CacheKeyFixture as CacheKeyFixture +from .sql import ( + ComputedReflectionFixtureTest as ComputedReflectionFixtureTest, +) +from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture +from .sql import NoCache as NoCache +from .sql import RemovesEvents as RemovesEvents +from .sql import TablesTest as TablesTest diff --git a/lib/sqlalchemy/testing/fixtures/base.py b/lib/sqlalchemy/testing/fixtures/base.py new file mode 100644 index 0000000000..199ae7134e --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures/base.py @@ -0,0 +1,366 @@ +# testing/fixtures/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import sqlalchemy as sa +from .. import assertions +from .. import config +from ..assertions import eq_ +from ..util import drop_all_tables_from_metadata +from ... import Column +from ... import func +from ... import Integer +from ... import select +from ... import Table +from ...orm import DeclarativeBase +from ...orm import MappedAsDataclass +from ...orm import registry + + +@config.mark_base_test_class() +class TestBase: + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + + def assert_(self, val, msg=None): + assert val, msg + + @config.fixture() + def nocache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + @config.fixture() + def connection_no_trans(self): + eng = getattr(self, "bind", None) or config.db + + with eng.connect() as conn: + yield conn + + @config.fixture() + def connection(self): + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db + + conn = eng.connect() + trans = conn.begin() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() + + @config.fixture() + def close_result_when_finished(self): + to_close = [] + to_consume = [] + + def go(result, consume=False): + to_close.append(result) + if consume: + to_consume.append(result) + + yield go + for r in to_consume: + try: + r.all() + except: + pass + for r in to_close: + try: + r.close() + except: + pass + + @config.fixture() + def registry(self, metadata): + reg = registry( + metadata=metadata, + type_annotation_map={ + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + }, + ) + yield reg + reg.dispose() + + @config.fixture + def decl_base(self, metadata): + _md = metadata + + class Base(DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } + + yield Base + Base.registry.dispose() + + @config.fixture + def dc_decl_base(self, metadata): + _md = metadata + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + + @config.fixture() + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection + + @config.fixture() + def future_engine(self): + yield + + @config.fixture() + def testing_engine(self): + from .. import engines + + def gen_testing_engine( + url=None, + options=None, + future=None, + asyncio=False, + transfer_staticpool=False, + share_pool=False, + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, + options=options, + asyncio=asyncio, + transfer_staticpool=transfer_staticpool, + share_pool=share_pool, + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") + + @config.fixture() + def async_testing_engine(self, testing_engine): + def go(**kw): + kw["asyncio"] = True + return testing_engine(**kw) + + return go + + @config.fixture() + def metadata(self, request): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from ...sql import schema + + metadata = schema.MetaData() + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) + + @config.fixture( + params=[ + (rollback, second_operation, begin_nested) + for rollback in (True, False) + for second_operation in ("none", "execute", "begin") + for begin_nested in ( + True, + False, + ) + ] + ) + def trans_ctx_manager_fixture(self, request, metadata): + rollback, second_operation, begin_nested = request.param + + t = Table("test", metadata, Column("data", Integer)) + eng = getattr(self, "bind", None) or config.db + + t.create(eng) + + def run_test(subject, trans_on_subject, execute_on_subject): + with subject.begin() as trans: + if begin_nested: + if not config.requirements.savepoints.enabled: + config.skip_test("savepoints not enabled") + if execute_on_subject: + nested_trans = subject.begin_nested() + else: + nested_trans = trans.begin_nested() + + with nested_trans: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + # for nested trans, we always commit/rollback on the + # "nested trans" object itself. + # only Session(future=False) will affect savepoint + # transaction for session.commit/rollback + + if rollback: + nested_trans.rollback() + else: + nested_trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction " + "inside context " + "manager. Please complete the context " + "manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute( + t.insert(), {"data": 12} + ) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + # outside the nested trans block, but still inside the + # transaction block, we can run SQL, and it will be + # committed + if execute_on_subject: + subject.execute(t.insert(), {"data": 14}) + else: + trans.execute(t.insert(), {"data": 14}) + + else: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + if trans_on_subject: + if rollback: + subject.rollback() + else: + subject.commit() + else: + if rollback: + trans.rollback() + else: + trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction inside " + "context " + "manager. Please complete the context manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute(t.insert(), {"data": 12}) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if hasattr(trans, "begin"): + trans.begin() + else: + subject.begin() + elif second_operation == "begin_nested": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + expected_committed = 0 + if begin_nested: + # begin_nested variant, we inserted a row after the nested + # block + expected_committed += 1 + if not rollback: + # not rollback variant, our row inserted in the target + # block itself would be committed + expected_committed += 1 + + if execute_on_subject: + eq_( + subject.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + else: + with subject.connect() as conn: + eq_( + conn.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + + return run_test + + +_connection_fixture_connection = None + + +class FutureEngineMixin: + """alembic's suite still using this""" diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py new file mode 100644 index 0000000000..80e5ee0733 --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures/mypy.py @@ -0,0 +1,308 @@ +# testing/fixtures/mypy.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import inspect +import os +from pathlib import Path +import re +import shutil +import sys +import tempfile + +from .base import TestBase +from .. import config +from ..assertions import eq_ +from ... import util + + +@config.add_to_marker.mypy +class MypyTest(TestBase): + __requires__ = ("no_sqlalchemy2_stubs",) + + @config.fixture(scope="function") + def per_func_cachedir(self): + yield from self._cachedir() + + @config.fixture(scope="class") + def cachedir(self): + yield from self._cachedir() + + def _cachedir(self): + # as of mypy 0.971 i think we need to keep mypy_path empty + mypy_path = "" + + with tempfile.TemporaryDirectory() as cachedir: + with open( + Path(cachedir) / "sqla_mypy_config.cfg", "w" + ) as config_file: + config_file.write( + f""" + [mypy]\n + plugins = sqlalchemy.ext.mypy.plugin\n + show_error_codes = True\n + {mypy_path} + disable_error_code = no-untyped-call + + [mypy-sqlalchemy.*] + ignore_errors = True + + """ + ) + with open( + Path(cachedir) / "plain_mypy_config.cfg", "w" + ) as config_file: + config_file.write( + f""" + [mypy]\n + show_error_codes = True\n + {mypy_path} + disable_error_code = var-annotated,no-untyped-call + [mypy-sqlalchemy.*] + ignore_errors = True + + """ + ) + yield cachedir + + @config.fixture() + def mypy_runner(self, cachedir): + from mypy import api + + def run(path, use_plugin=False, use_cachedir=None): + if use_cachedir is None: + use_cachedir = cachedir + args = [ + "--strict", + "--raise-exceptions", + "--cache-dir", + use_cachedir, + "--config-file", + os.path.join( + use_cachedir, + "sqla_mypy_config.cfg" + if use_plugin + else "plain_mypy_config.cfg", + ), + ] + + # mypy as of 0.990 is more aggressively blocking messaging + # for paths that are in sys.path, and as pytest puts currdir, + # test/ etc in sys.path, just copy the source file to the + # tempdir we are working in so that we don't have to try to + # manipulate sys.path and/or guess what mypy is doing + filename = os.path.basename(path) + test_program = os.path.join(use_cachedir, filename) + if path != test_program: + shutil.copyfile(path, test_program) + args.append(test_program) + + # I set this locally but for the suite here needs to be + # disabled + os.environ.pop("MYPY_FORCE_COLOR", None) + + stdout, stderr, exitcode = api.run(args) + return stdout, stderr, exitcode + + return run + + @config.fixture + def mypy_typecheck_file(self, mypy_runner): + def run(path, use_plugin=False): + expected_messages = self._collect_messages(path) + stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin) + self._check_output( + path, expected_messages, stdout, stderr, exitcode + ) + + return run + + @staticmethod + def file_combinations(dirname): + if os.path.isabs(dirname): + path = dirname + else: + caller_path = inspect.stack()[1].filename + path = os.path.join(os.path.dirname(caller_path), dirname) + files = list(Path(path).glob("**/*.py")) + + for extra_dir in config.options.mypy_extra_test_paths: + if extra_dir and os.path.isdir(extra_dir): + files.extend((Path(extra_dir) / dirname).glob("**/*.py")) + return files + + def _collect_messages(self, path): + from sqlalchemy.ext.mypy.util import mypy_14 + + expected_messages = [] + expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") + py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") + with open(path) as file_: + current_assert_messages = [] + for num, line in enumerate(file_, 1): + m = py_ver_re.match(line) + if m: + major, _, minor = m.group(1).partition(".") + if sys.version_info < (int(major), int(minor)): + config.skip_test( + "Requires python >= %s" % (m.group(1)) + ) + continue + + m = expected_re.match(line) + if m: + is_mypy = bool(m.group(1)) + is_re = bool(m.group(2)) + is_type = bool(m.group(3)) + + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) + if is_type: + if not is_re: + # the goal here is that we can cut-and-paste + # from vscode -> pylance into the + # EXPECTED_TYPE: line, then the test suite will + # validate that line against what mypy produces + expected_msg = re.sub( + r"([\[\]])", + lambda m: rf"\{m.group(0)}", + expected_msg, + ) + + # note making sure preceding text matches + # with a dot, so that an expect for "Select" + # does not match "TypedSelect" + expected_msg = re.sub( + r"([\w_]+)", + lambda m: rf"(?:.*\.)?{m.group(1)}\*?", + expected_msg, + ) + + expected_msg = re.sub( + "List", "builtins.list", expected_msg + ) + + expected_msg = re.sub( + r"\b(int|str|float|bool)\b", + lambda m: rf"builtins.{m.group(0)}\*?", + expected_msg, + ) + # expected_msg = re.sub( + # r"(Sequence|Tuple|List|Union)", + # lambda m: fr"typing.{m.group(0)}\*?", + # expected_msg, + # ) + + is_mypy = is_re = True + expected_msg = f'Revealed type is "{expected_msg}"' + + if mypy_14 and util.py39: + # use_lowercase_names, py39 and above + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501 + + # skip first character which could be capitalized + # "List item x not found" type of message + expected_msg = expected_msg[0] + re.sub( + r"\b(List|Tuple|Dict|Set)\b" + if is_type + else r"\b(List|Tuple|Dict|Set|Type)\b", + lambda m: m.group(1).lower(), + expected_msg[1:], + ) + + if mypy_14 and util.py310: + # use_or_syntax, py310 and above + # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501 + expected_msg = re.sub( + r"Optional\[(.*?)\]", + lambda m: f"{m.group(1)} | None", + expected_msg, + ) + current_assert_messages.append( + (is_mypy, is_re, expected_msg.strip()) + ) + elif current_assert_messages: + expected_messages.extend( + (num, is_mypy, is_re, expected_msg) + for ( + is_mypy, + is_re, + expected_msg, + ) in current_assert_messages + ) + current_assert_messages[:] = [] + + return expected_messages + + def _check_output(self, path, expected_messages, stdout, stderr, exitcode): + not_located = [] + filename = os.path.basename(path) + if expected_messages: + # mypy 0.990 changed how return codes work, so don't assume a + # 1 or a 0 return code here, could be either depending on if + # errors were generated or not + + output = [] + + raw_lines = stdout.split("\n") + while raw_lines: + e = raw_lines.pop(0) + if re.match(r".+\.py:\d+: error: .*", e): + output.append(("error", e)) + elif re.match( + r".+\.py:\d+: note: +(?:Possible overload|def ).*", e + ): + while raw_lines: + ol = raw_lines.pop(0) + if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + break + elif re.match( + r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I + ): + pass + elif re.match(r".+\.py:\d+: note: .*", e): + output.append(("note", e)) + + for num, is_mypy, is_re, msg in expected_messages: + msg = msg.replace("'", '"') + prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" + for idx, (typ, errmsg) in enumerate(output): + if is_re: + if re.match( + rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}", + errmsg, + ): + break + elif ( + f"{filename}:{num}: {typ}: {prefix}{msg}" + in errmsg.replace("'", '"') + ): + break + else: + not_located.append(msg) + continue + del output[idx] + + if not_located: + missing = "\n".join(not_located) + print("Couldn't locate expected messages:", missing, sep="\n") + if output: + extra = "\n".join(msg for _, msg in output) + print("Remaining messages:", extra, sep="\n") + assert False, "expected messages not found, see stdout" + + if output: + print(f"{len(output)} messages from mypy were not consumed:") + print("\n".join(msg for _, msg in output)) + assert False, "errors and/or notes remain, see stdout" + + else: + if exitcode != 0: + print(stdout, stderr, sep="\n") + + eq_(exitcode, 0, msg=stdout) diff --git a/lib/sqlalchemy/testing/fixtures/orm.py b/lib/sqlalchemy/testing/fixtures/orm.py new file mode 100644 index 0000000000..da622c068c --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures/orm.py @@ -0,0 +1,227 @@ +# testing/fixtures/orm.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any + +import sqlalchemy as sa +from .base import TestBase +from .sql import TablesTest +from .. import assertions +from .. import config +from .. import schema +from ..entities import BasicEntity +from ..entities import ComparableEntity +from ..util import adict +from ... import orm +from ...orm import DeclarativeBase +from ...orm import events as orm_events +from ...orm import registry + + +class ORMTest(TestBase): + @config.fixture + def fixture_session(self): + return fixture_session() + + +class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = "once" + + # 'once', 'each', None + run_setup_mappers = "each" + + classes: Any = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + yield + + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_classes() + self._setup_each_mappers() + self._setup_each_inserts() + + yield + + orm.session.close_all_sessions() + self._teardown_each_mappers() + self._teardown_each_classes() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == "once": + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == "once": + cls.mapper_registry, cls.mapper = cls._generate_registry() + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers != "once": + ( + self.__class__.mapper_registry, + self.__class__.mapper, + ) = self._generate_registry() + + if self.run_setup_mappers == "each": + self._with_register_classes(self.setup_mappers) + + def _setup_each_classes(self): + if self.run_setup_classes == "each": + self._with_register_classes(self.setup_classes) + + @classmethod + def _generate_registry(cls): + decl = registry(metadata=cls._tables_metadata) + return decl, decl.map_imperatively + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + + class _Base: + def __init_subclass__(cls) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__() + + class Basic(BasicEntity, _Base): + pass + + class Comparable(ComparableEntity, _Base): + pass + + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != "once": + orm.clear_mappers() + + def _teardown_each_classes(self): + if self.run_setup_classes != "once": + self.classes.clear() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = "once" + run_setup_mappers = "once" + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + + class _DeclBase(DeclarativeBase): + __table_cls__ = schema.Table + metadata = cls._tables_metadata + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb", "oracle" + ) + } + + def __init_subclass__(cls, **kw) -> None: + assert cls_registry is not None + cls_registry[cls.__name__] = cls + super().__init_subclass__(**kw) + + cls.DeclarativeBasic = _DeclBase + + # sets up cls.Basic which is helpful for things like composite + # classes + super()._with_register_classes(fn) + + if cls._tables_metadata.tables and cls.run_create_tables: + cls._tables_metadata.create_all(config.db) + + +class RemoveORMEventsGlobally: + @config.fixture(autouse=True) + def _remove_listeners(self): + yield + orm_events.MapperEvents._clear() + orm_events.InstanceEvents._clear() + orm_events.SessionEvents._clear() + orm_events.InstrumentationEvents._clear() + orm_events.QueryEvents._clear() + + +_fixture_sessions = set() + + +def fixture_session(**kw): + kw.setdefault("autoflush", True) + kw.setdefault("expire_on_commit", True) + + bind = kw.pop("bind", config.db) + + sess = orm.Session(bind, **kw) + _fixture_sessions.add(sess) + return sess + + +def close_all_sessions(): + # will close all still-referenced sessions + orm.close_all_sessions() + _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): + close_all_sessions() + orm.clear_mappers() + + +def after_test(): + if _fixture_sessions: + close_all_sessions() diff --git a/lib/sqlalchemy/testing/fixtures/sql.py b/lib/sqlalchemy/testing/fixtures/sql.py new file mode 100644 index 0000000000..911dddda31 --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures/sql.py @@ -0,0 +1,492 @@ +# testing/fixtures/sql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +import itertools +import random +import re +import sys + +import sqlalchemy as sa +from .base import TestBase +from .. import config +from .. import mock +from ..assertions import eq_ +from ..assertions import ne_ +from ..util import adict +from ..util import drop_all_tables_from_metadata +from ... import event +from ... import util +from ...schema import sort_tables_and_constraints +from ...sql import visitors +from ...sql.elements import ClauseElement + + +class TablesTest(TestBase): + # 'once', None + run_setup_bind = "once" + + # 'once', 'each', None + run_define_tables = "once" + + # 'once', 'each', None + run_create_tables = "once" + + # 'once', 'each', None + run_inserts = "each" + + # 'each', None + run_deletes = "each" + + # 'once', None + run_dispose_bind = None + + bind = None + _tables_metadata = None + tables = None + other = None + sequences = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + + @classmethod + def _init_class(cls): + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) + + cls.other = adict() + cls.tables = adict() + cls.sequences = adict() + + cls.bind = cls.setup_bind() + cls._tables_metadata = sa.MetaData() + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == "once": + cls._load_fixtures() + with cls.bind.begin() as conn: + cls.insert_data(conn) + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == "once": + cls.define_tables(cls._tables_metadata) + if cls.run_create_tables == "once": + cls._tables_metadata.create_all(cls.bind) + cls.tables.update(cls._tables_metadata.tables) + cls.sequences.update(cls._tables_metadata._sequences) + + def _setup_each_tables(self): + if self.run_define_tables == "each": + self.define_tables(self._tables_metadata) + if self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + self.tables.update(self._tables_metadata.tables) + self.sequences.update(self._tables_metadata._sequences) + elif self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == "each": + self._load_fixtures() + with self.bind.begin() as conn: + self.insert_data(conn) + + def _teardown_each_tables(self): + if self.run_define_tables == "each": + self.tables.clear() + if self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + self._tables_metadata.clear() + elif self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + + savepoints = getattr(config.requirements, "savepoints", False) + if savepoints: + savepoints = savepoints.enabled + + # no need to run deletes if tables are recreated on setup + if ( + self.run_define_tables != "each" + and self.run_create_tables != "each" + and self.run_deletes == "each" + ): + with self.bind.begin() as conn: + for table in reversed( + [ + t + for (t, fks) in sort_tables_and_constraints( + self._tables_metadata.tables.values() + ) + if t is not None + ] + ): + try: + if savepoints: + with conn.begin_nested(): + conn.execute(table.delete()) + else: + conn.execute(table.delete()) + except sa.exc.DBAPIError as ex: + print( + ("Error emptying table %s: %r" % (table, ex)), + file=sys.stderr, + ) + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) + + if cls.run_dispose_bind == "once": + cls.dispose_bind(cls.bind) + + cls._tables_metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, "dispose"): + bind.dispose() + elif hasattr(bind, "close"): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls, connection): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements): + self.assert_sql(self.bind, callable_, statements) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().items(): + if len(data) < 2: + continue + if isinstance(table, str): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table, fks in sort_tables_and_constraints( + cls._tables_metadata.tables.values() + ): + if table is None: + continue + if table not in headers: + continue + with cls.bind.begin() as conn: + conn.execute( + table.insert(), + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + + +class NoCache: + @config.fixture(autouse=True, scope="function") + def _disable_cache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + +class RemovesEvents: + @util.memoized_property + def _event_fns(self): + return set() + + def event_listen(self, target, name, fn, **kw): + self._event_fns.add((target, name, fn)) + event.listen(target, name, fn, **kw) + + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield + for key in self._event_fns: + event.remove(*key) + + +class ComputedReflectionFixtureTest(TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("computed_columns", "table_reflection") + + regexp = re.compile(r"[\[\]\(\)\s`'\"]*") + + def normalize(self, text): + return self.regexp.sub("", text).lower() + + @classmethod + def define_tables(cls, metadata): + from ... import Integer + from ... import testing + from ...schema import Column + from ...schema import Computed + from ...schema import Table + + Table( + "computed_default_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_col", Integer, Computed("normal + 42")), + Column("with_default", Integer, server_default="42"), + ) + + t = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal + 42")), + ) + + if testing.requires.schemas.enabled: + t2 = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal / 42")), + schema=config.test_schema, + ) + + if testing.requires.computed_columns_virtual.enabled: + t.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal + 2", persisted=False), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal / 2", persisted=False), + ) + ) + if testing.requires.computed_columns_stored.enabled: + t.append_column( + Column( + "computed_stored", + Integer, + Computed("normal - 42", persisted=True), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_stored", + Integer, + Computed("normal * 42", persisted=True), + ) + ) + + +class CacheKeyFixture: + def _compare_equal(self, a, b, compare_values): + a_key = a._generate_cache_key() + b_key = b._generate_cache_key() + + if a_key is None: + assert a._annotations.get("nocache") + + assert b_key is None + else: + eq_(a_key.key, b_key.key) + eq_(hash(a_key.key), hash(b_key.key)) + + for a_param, b_param in zip(a_key.bindparams, b_key.bindparams): + assert a_param.compare(b_param, compare_values=compare_values) + return a_key, b_key + + def _run_cache_key_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + if a == b: + a_key, b_key = self._compare_equal( + case_a[a], case_b[b], compare_values + ) + if a_key is None: + continue + else: + a_key = case_a[a]._generate_cache_key() + b_key = case_b[b]._generate_cache_key() + + if a_key is None or b_key is None: + if a_key is None: + assert case_a[a]._annotations.get("nocache") + if b_key is None: + assert case_b[b]._annotations.get("nocache") + continue + + if a_key.key == b_key.key: + for a_param, b_param in zip( + a_key.bindparams, b_key.bindparams + ): + if not a_param.compare( + b_param, compare_values=compare_values + ): + break + else: + # this fails unconditionally since we could not + # find bound parameter values that differed. + # Usually we intended to get two distinct keys here + # so the failure will be more descriptive using the + # ne_() assertion. + ne_(a_key.key, b_key.key) + else: + ne_(a_key.key, b_key.key) + + # ClauseElement-specific test to ensure the cache key + # collected all the bound parameters that aren't marked + # as "literal execute" + if isinstance(case_a[a], ClauseElement) and isinstance( + case_b[b], ClauseElement + ): + assert_a_params = [] + assert_b_params = [] + + for elem in visitors.iterate(case_a[a]): + if elem.__visit_name__ == "bindparam": + assert_a_params.append(elem) + + for elem in visitors.iterate(case_b[b]): + if elem.__visit_name__ == "bindparam": + assert_b_params.append(elem) + + # note we're asserting the order of the params as well as + # if there are dupes or not. ordering has to be + # deterministic and matches what a traversal would provide. + eq_( + sorted(a_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_a_params), key=lambda b: b.key + ), + ) + eq_( + sorted(b_key.bindparams, key=lambda b: b.key), + sorted( + util.unique_list(assert_b_params), key=lambda b: b.key + ), + ) + + def _run_cache_key_equal_fixture(self, fixture, compare_values): + case_a = fixture() + case_b = fixture() + + for a, b in itertools.combinations_with_replacement( + range(len(case_a)), 2 + ): + self._compare_equal(case_a[a], case_b[b], compare_values) + + +def insertmanyvalues_fixture( + connection, randomize_rows=False, warn_on_downgraded=False +): + dialect = connection.dialect + orig_dialect = dialect._deliver_insertmanyvalues_batches + orig_conn = connection._exec_insertmany_context + + class RandomCursor: + __slots__ = ("cursor",) + + def __init__(self, cursor): + self.cursor = cursor + + # only this method is called by the deliver method. + # by not having the other methods we assert that those aren't being + # used + + def fetchall(self): + rows = self.cursor.fetchall() + rows = list(rows) + random.shuffle(rows) + return rows + + def _deliver_insertmanyvalues_batches( + cursor, statement, parameters, generic_setinputsizes, context + ): + if randomize_rows: + cursor = RandomCursor(cursor) + for batch in orig_dialect( + cursor, statement, parameters, generic_setinputsizes, context + ): + if warn_on_downgraded and batch.is_downgraded: + util.warn("Batches were downgraded for sorted INSERT") + + yield batch + + def _exec_insertmany_context( + dialect, + context, + ): + with mock.patch.object( + dialect, + "_deliver_insertmanyvalues_batches", + new=_deliver_insertmanyvalues_batches, + ): + return orig_conn(dialect, context) + + connection._exec_insertmany_context = _exec_insertmany_context diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index b0823983fe..89155a8419 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -13,20 +13,20 @@ unpickling. from __future__ import annotations -from . import fixtures +from .entities import ComparableEntity from ..schema import Column from ..types import String -class User(fixtures.ComparableEntity): +class User(ComparableEntity): pass -class Order(fixtures.ComparableEntity): +class Order(ComparableEntity): pass -class Dingaling(fixtures.ComparableEntity): +class Dingaling(ComparableEntity): pass @@ -34,20 +34,20 @@ class EmailUser(User): pass -class Address(fixtures.ComparableEntity): +class Address(ComparableEntity): pass # TODO: these are kind of arbitrary.... -class Child1(fixtures.ComparableEntity): +class Child1(ComparableEntity): pass -class Child2(fixtures.ComparableEntity): +class Child2(ComparableEntity): pass -class Parent(fixtures.ComparableEntity): +class Parent(ComparableEntity): pass @@ -61,7 +61,7 @@ class Mixin: email_address = Column(String) -class AddressWMixin(Mixin, fixtures.ComparableEntity): +class AddressWMixin(Mixin, ComparableEntity): pass diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index cff53ea727..393070d08c 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -10,6 +10,7 @@ from __future__ import annotations import abc +from argparse import Namespace import configparser import logging import os @@ -51,7 +52,7 @@ file_config = None logging = None include_tags = set() exclude_tags = set() -options = None +options: Namespace = None # type: ignore def setup_options(make_option): diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 17bd038d38..a676e7e28d 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -11,13 +11,15 @@ import operator import os import re import sys +from typing import TYPE_CHECKING import uuid import pytest try: # installed by bootstrap.py - import sqla_plugin_base as plugin_base + if not TYPE_CHECKING: + import sqla_plugin_base as plugin_base except ImportError: # assume we're a package, use traditional import from . import plugin_base diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index ccd06716e0..cf24b43a96 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -216,22 +216,21 @@ def provide_metadata(fn, *args, **kw): # we have to hardcode some of that cleanup ahead of time. # close ORM sessions - fixtures._close_all_sessions() + fixtures.close_all_sessions() # integrate with the "connection" fixture as there are many # tests where it is used along with provide_metadata - if fixtures._connection_fixture_connection: + cfc = fixtures.base._connection_fixture_connection + if cfc: # TODO: this warning can be used to find all the places # this is used with connection fixture # warn("mixing legacy provide metadata with connection fixture") - drop_all_tables_from_metadata( - metadata, fixtures._connection_fixture_connection - ) + drop_all_tables_from_metadata(metadata, cfc) # as the provide_metadata fixture is often used with "testing.db", # when we do the drop we have to commit the transaction so that # the DB is actually updated as the CREATE would have been # committed - fixtures._connection_fixture_connection.get_transaction().commit() + cfc.get_transaction().commit() else: drop_all_tables_from_metadata(metadata, config.db) self.metadata = prev_meta diff --git a/setup.cfg b/setup.cfg index 4857a19f14..ed5a4a92e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -113,7 +113,7 @@ application-import-names = sqlalchemy,test per-file-ignores = **/__init__.py:F401 test/*:FA100 - test/ext/mypy/plain_files/*:F821,E501,FA100 + test/typing/plain_files/*:F821,E501,FA100 test/ext/mypy/plugin_files/*:F821,E501,FA100 lib/sqlalchemy/events.py:F401 lib/sqlalchemy/schema.py:F401 diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 047853e675..fc6be0f096 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -46,6 +46,7 @@ from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import pickleable +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -53,11 +54,11 @@ from sqlalchemy.testing.util import gc_collect from ..orm import _fixtures -class A(fixtures.ComparableEntity): +class A(ComparableEntity): pass -class B(fixtures.ComparableEntity): +class B(ComparableEntity): pass @@ -916,7 +917,7 @@ class MemUsageWBackendTest(fixtures.MappedTest, EnsureZeroed): @profile_memory() def go(): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass class B(A): @@ -997,10 +998,10 @@ class MemUsageWBackendTest(fixtures.MappedTest, EnsureZeroed): @profile_memory() def go(): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 1767f2f4e3..228489349a 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -41,6 +41,7 @@ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertions import is_false +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.provision import normalize_sequence from .test_engine_py3k import AsyncFixture as _AsyncFixture from ...orm import _fixtures @@ -1023,7 +1024,7 @@ class AsyncAttrsTest( def decl_base(self, metadata): _md = metadata - class Base(fixtures.ComparableEntity, AsyncAttrs, DeclarativeBase): + class Base(ComparableEntity, AsyncAttrs, DeclarativeBase): metadata = _md type_annotation_map = { str: String().with_variant( diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 62e15a124b..d6d059cbef 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -21,6 +21,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_raises_message +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally from sqlalchemy.testing.schema import Column @@ -144,7 +145,7 @@ class ConcreteInhTest( "punion", ) - class Employee(Base, fixtures.ComparableEntity): + class Employee(Base, ComparableEntity): __table__ = punion __mapper_args__ = {"polymorphic_on": punion.c.type} @@ -174,7 +175,7 @@ class ConcreteInhTest( def test_concrete_inline_non_polymorphic(self): """test the example from the declarative docs.""" - class Employee(Base, fixtures.ComparableEntity): + class Employee(Base, ComparableEntity): __tablename__ = "people" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -211,7 +212,7 @@ class ConcreteInhTest( self._roundtrip(Employee, Manager, Engineer, Boss, polymorphic=False) def test_abstract_concrete_base_didnt_configure(self): - class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class Employee(AbstractConcreteBase, Base, ComparableEntity): strict_attrs = True assert_raises_message( @@ -269,7 +270,7 @@ class ConcreteInhTest( ) def test_abstract_concrete_extension(self): - class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class Employee(AbstractConcreteBase, Base, ComparableEntity): name = Column(String(50)) class Manager(Employee): @@ -321,7 +322,7 @@ class ConcreteInhTest( def test_abstract_concrete_extension_descriptor_refresh( self, use_strict_attrs ): - class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class Employee(AbstractConcreteBase, Base, ComparableEntity): strict_attrs = use_strict_attrs @declared_attr @@ -378,7 +379,7 @@ class ConcreteInhTest( eq_(e1.name, "d") def test_concrete_extension(self): - class Employee(ConcreteBase, Base, fixtures.ComparableEntity): + class Employee(ConcreteBase, Base, ComparableEntity): __tablename__ = "employee" employee_id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -428,7 +429,7 @@ class ConcreteInhTest( self._roundtrip(Employee, Manager, Engineer, Boss) def test_concrete_extension_warn_for_overlap(self): - class Employee(ConcreteBase, Base, fixtures.ComparableEntity): + class Employee(ConcreteBase, Base, ComparableEntity): __tablename__ = "employee" employee_id = Column( @@ -463,7 +464,7 @@ class ConcreteInhTest( configure_mappers() def test_concrete_extension_warn_concrete_disc_resolves_overlap(self): - class Employee(ConcreteBase, Base, fixtures.ComparableEntity): + class Employee(ConcreteBase, Base, ComparableEntity): _concrete_discriminator_name = "_type" __tablename__ = "employee" @@ -562,7 +563,7 @@ class ConcreteInhTest( ) def test_abs_concrete_extension_warn_for_overlap(self): - class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class Employee(AbstractConcreteBase, Base, ComparableEntity): name = Column(String(50)) __mapper_args__ = { "polymorphic_identity": "employee", @@ -595,7 +596,7 @@ class ConcreteInhTest( def test_abs_concrete_extension_warn_concrete_disc_resolves_overlap( self, use_strict_attrs ): - class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class Employee(AbstractConcreteBase, Base, ComparableEntity): strict_attrs = use_strict_attrs _concrete_discriminator_name = "_type" @@ -671,7 +672,7 @@ class ConcreteInhTest( assert PolyTest.__mapper__.polymorphic_on is Test.__table__.c.type def test_ok_to_override_type_from_abstract(self): - class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class Employee(AbstractConcreteBase, Base, ComparableEntity): name = Column(String(50)) class Manager(Employee): @@ -734,7 +735,7 @@ class ConcreteExtensionConfigTest( __dialect__ = "default" def test_classreg_setup(self): - class A(Base, fixtures.ComparableEntity): + class A(Base, ComparableEntity): __tablename__ = "a" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -744,7 +745,7 @@ class ConcreteExtensionConfigTest( "BC", primaryjoin="BC.a_id == A.id", collection_class=set ) - class BC(AbstractConcreteBase, Base, fixtures.ComparableEntity): + class BC(AbstractConcreteBase, Base, ComparableEntity): a_id = Column(Integer, ForeignKey("a.id")) class B(BC): diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index 103d3d07ff..4f81d7c470 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -22,6 +22,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -74,11 +75,11 @@ class DeferredReflectPKFKTest(DeferredReflectBase): ) def test_pk_fk(self): - class B(DeferredReflection, fixtures.ComparableEntity, Base): + class B(DeferredReflection, ComparableEntity, Base): __tablename__ = "b" a = relationship("A") - class A(DeferredReflection, fixtures.ComparableEntity, Base): + class A(DeferredReflection, ComparableEntity, Base): __tablename__ = "a" DeferredReflection.prepare(testing.db) @@ -133,11 +134,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): eq_(a1.user, User(name="u1")) def test_exception_prepare_not_called(self): - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" addresses = relationship("Address", backref="user") - class Address(DeferredReflection, fixtures.ComparableEntity, Base): + class Address(DeferredReflection, ComparableEntity, Base): __tablename__ = "addresses" assert_raises_message( @@ -152,11 +153,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): @testing.variation("bind", ["engine", "connection", "raise_"]) def test_basic_deferred(self, bind): - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" addresses = relationship("Address", backref="user") - class Address(DeferredReflection, fixtures.ComparableEntity, Base): + class Address(DeferredReflection, ComparableEntity, Base): __tablename__ = "addresses" if bind.engine: @@ -218,11 +219,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): class OtherDefBase(DeferredReflection, Base): __abstract__ = True - class User(fixtures.ComparableEntity, DefBase): + class User(ComparableEntity, DefBase): __tablename__ = "users" addresses = relationship("Address", backref="user") - class Address(fixtures.ComparableEntity, DefBase): + class Address(ComparableEntity, DefBase): __tablename__ = "addresses" class Fake(OtherDefBase): @@ -232,11 +233,11 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): self._roundtrip() def test_redefine_fk_double(self): - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" addresses = relationship("Address", backref="user") - class Address(DeferredReflection, fixtures.ComparableEntity, Base): + class Address(DeferredReflection, ComparableEntity, Base): __tablename__ = "addresses" user_id = Column(Integer, ForeignKey("users.id")) @@ -247,7 +248,7 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): """test that __mapper_args__ is not called until *after* table reflection""" - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" @declared_attr @@ -277,10 +278,10 @@ class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): @testing.requires.predictable_gc def test_cls_not_strong_ref(self): - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" - class Address(DeferredReflection, fixtures.ComparableEntity, Base): + class Address(DeferredReflection, ComparableEntity, Base): __tablename__ = "addresses" eq_(len(_DeferredMapperConfig._configs), 2) @@ -340,26 +341,26 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): ) def test_string_resolution(self): - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" items = relationship("Item", secondary="user_items") - class Item(DeferredReflection, fixtures.ComparableEntity, Base): + class Item(DeferredReflection, ComparableEntity, Base): __tablename__ = "items" DeferredReflection.prepare(testing.db) self._roundtrip() def test_table_resolution(self): - class User(DeferredReflection, fixtures.ComparableEntity, Base): + class User(DeferredReflection, ComparableEntity, Base): __tablename__ = "users" items = relationship( "Item", secondary=Table("user_items", Base.metadata) ) - class Item(DeferredReflection, fixtures.ComparableEntity, Base): + class Item(DeferredReflection, ComparableEntity, Base): __tablename__ = "items" DeferredReflection.prepare(testing.db) @@ -408,7 +409,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): ) def test_basic(self, decl_base): - class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + class Foo(DeferredReflection, ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -422,7 +423,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_add_subclass_column(self, decl_base): - class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + class Foo(DeferredReflection, ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -437,7 +438,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_add_subclass_mapped_column(self, decl_base): - class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + class Foo(DeferredReflection, ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -452,7 +453,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_subclass_mapped_column_no_existing(self, decl_base): - class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + class Foo(DeferredReflection, ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -469,7 +470,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): bar_data: Mapped[str] = mapped_column(use_existing_column=True) def test_add_pk_column(self, decl_base): - class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + class Foo(DeferredReflection, ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -484,7 +485,7 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_add_pk_mapped_column(self, decl_base): - class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + class Foo(DeferredReflection, ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -521,7 +522,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): ) def test_basic(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + class Foo(DeferredReflection, ComparableEntity, Base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -536,7 +537,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_add_subclass_column(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + class Foo(DeferredReflection, ComparableEntity, Base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -552,7 +553,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_add_pk_column(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + class Foo(DeferredReflection, ComparableEntity, Base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -568,7 +569,7 @@ class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): self._roundtrip() def test_add_fk_pk_column(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + class Foo(DeferredReflection, ComparableEntity, Base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", diff --git a/test/ext/mypy/plain_files/core_ddl.py b/test/ext/mypy/plain_files/core_ddl.py deleted file mode 100644 index 673a90e943..0000000000 --- a/test/ext/mypy/plain_files/core_ddl.py +++ /dev/null @@ -1,43 +0,0 @@ -from sqlalchemy import CheckConstraint -from sqlalchemy import Column -from sqlalchemy import DateTime -from sqlalchemy import ForeignKey -from sqlalchemy import Index -from sqlalchemy import Integer -from sqlalchemy import MetaData -from sqlalchemy import PrimaryKeyConstraint -from sqlalchemy import String -from sqlalchemy import Table - - -m = MetaData() - - -t1 = Table( - "t1", - m, - Column("id", Integer, primary_key=True), - Column("data", String), - Column("data2", String(50)), - Column("timestamp", DateTime()), - Index(None, "data2"), -) - -t2 = Table( - "t2", - m, - Column("t1id", ForeignKey("t1.id")), - Column("q", Integer, CheckConstraint("q > 5")), -) - -t3 = Table( - "t3", - m, - Column("x", Integer), - Column("y", Integer), - Column("t1id", ForeignKey(t1.c.id)), - PrimaryKeyConstraint("x", "y"), -) - -# cols w/ no name or type, used by declarative -c1: Column[int] = Column(ForeignKey(t3.c.x)) diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index a9cc1eb336..f1b36ac52b 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -1,35 +1,11 @@ import os -import re import shutil -import sys -import tempfile -from typing import Any -from typing import cast -from typing import List -from typing import Tuple from sqlalchemy import testing -from sqlalchemy import util -from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -def _file_combinations(dirname): - path = os.path.join(os.path.dirname(__file__), dirname) - files = [] - for f in os.listdir(path): - if f.endswith(".py"): - files.append(os.path.join(os.path.dirname(__file__), dirname, f)) - - for extra_dir in testing.config.options.mypy_extra_test_paths: - if extra_dir and os.path.isdir(extra_dir): - for f in os.listdir(os.path.join(extra_dir, dirname)): - if f.endswith(".py"): - files.append(os.path.join(extra_dir, dirname, f)) - return files - - def _incremental_dirs(): path = os.path.join(os.path.dirname(__file__), "incremental") files = [] @@ -47,99 +23,9 @@ def _incremental_dirs(): return files -@testing.add_to_marker.mypy -class MypyPluginTest(fixtures.TestBase): - __tags__ = ("mypy",) - __requires__ = ("no_sqlalchemy2_stubs",) - - @testing.fixture(scope="function") - def per_func_cachedir(self): - yield from self._cachedir() - - @testing.fixture(scope="class") - def cachedir(self): - yield from self._cachedir() - - def _cachedir(self): - # as of mypy 0.971 i think we need to keep mypy_path empty - mypy_path = "" - - with tempfile.TemporaryDirectory() as cachedir: - with open( - os.path.join(cachedir, "sqla_mypy_config.cfg"), "w" - ) as config_file: - config_file.write( - f""" - [mypy]\n - plugins = sqlalchemy.ext.mypy.plugin\n - show_error_codes = True\n - {mypy_path} - disable_error_code = no-untyped-call - - [mypy-sqlalchemy.*] - ignore_errors = True - - """ - ) - with open( - os.path.join(cachedir, "plain_mypy_config.cfg"), "w" - ) as config_file: - config_file.write( - f""" - [mypy]\n - show_error_codes = True\n - {mypy_path} - disable_error_code = var-annotated,no-untyped-call - [mypy-sqlalchemy.*] - ignore_errors = True - - """ - ) - yield cachedir - - @testing.fixture() - def mypy_runner(self, cachedir): - from mypy import api - - def run(path, use_plugin=True, incremental=False): - args = [ - "--strict", - "--raise-exceptions", - "--cache-dir", - cachedir, - "--config-file", - os.path.join( - cachedir, - "sqla_mypy_config.cfg" - if use_plugin - else "plain_mypy_config.cfg", - ), - ] - - # mypy as of 0.990 is more aggressively blocking messaging - # for paths that are in sys.path, and as pytest puts currdir, - # test/ etc in sys.path, just copy the source file to the - # tempdir we are working in so that we don't have to try to - # manipulate sys.path and/or guess what mypy is doing - filename = os.path.basename(path) - test_program = os.path.join(cachedir, filename) - shutil.copyfile(path, test_program) - args.append(test_program) - - # I set this locally but for the suite here needs to be - # disabled - os.environ.pop("MYPY_FORCE_COLOR", None) - - result = api.run(args) - return result - - return run - +class MypyPluginTest(fixtures.MypyTest): @testing.combinations( - *[ - (pathname, testing.exclusions.closed()) - for pathname in _incremental_dirs() - ], + *[(pathname) for pathname in _incremental_dirs()], argnames="pathname", ) @testing.requires.patch_library @@ -175,7 +61,7 @@ class MypyPluginTest(fixtures.TestBase): result = mypy_runner( dest, use_plugin=True, - incremental=True, + use_cachedir=cachedir, ) eq_( result[2], @@ -186,191 +72,11 @@ class MypyPluginTest(fixtures.TestBase): @testing.combinations( *( - cast( - List[Tuple[Any, ...]], - [ - ("w_plugin", os.path.basename(path), path, True) - for path in _file_combinations("plugin_files") - ], - ) - + cast( - List[Tuple[Any, ...]], - [ - ("plain", os.path.basename(path), path, False) - for path in _file_combinations("plain_files") - ], - ) + (os.path.basename(path), path, True) + for path in fixtures.MypyTest.file_combinations("plugin_files") ), - argnames="filename,path,use_plugin", - id_="isaa", + argnames="path", + id_="ia", ) - def test_files(self, mypy_runner, filename, path, use_plugin): - expected_messages = [] - expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") - py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") - - from sqlalchemy.ext.mypy.util import mypy_14 - - with open(path) as file_: - current_assert_messages = [] - for num, line in enumerate(file_, 1): - m = py_ver_re.match(line) - if m: - major, _, minor = m.group(1).partition(".") - if sys.version_info < (int(major), int(minor)): - config.skip_test( - "Requires python >= %s" % (m.group(1)) - ) - continue - if line.startswith("# NOPLUGINS"): - use_plugin = False - continue - - m = expected_re.match(line) - if m: - is_mypy = bool(m.group(1)) - is_re = bool(m.group(2)) - is_type = bool(m.group(3)) - - expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) - - if is_type: - if not is_re: - # the goal here is that we can cut-and-paste - # from vscode -> pylance into the - # EXPECTED_TYPE: line, then the test suite will - # validate that line against what mypy produces - expected_msg = re.sub( - r"([\[\]])", - lambda m: rf"\{m.group(0)}", - expected_msg, - ) - - # note making sure preceding text matches - # with a dot, so that an expect for "Select" - # does not match "TypedSelect" - expected_msg = re.sub( - r"([\w_]+)", - lambda m: rf"(?:.*\.)?{m.group(1)}\*?", - expected_msg, - ) - - expected_msg = re.sub( - "List", "builtins.list", expected_msg - ) - - expected_msg = re.sub( - r"\b(int|str|float|bool)\b", - lambda m: rf"builtins.{m.group(0)}\*?", - expected_msg, - ) - # expected_msg = re.sub( - # r"(Sequence|Tuple|List|Union)", - # lambda m: fr"typing.{m.group(0)}\*?", - # expected_msg, - # ) - - is_mypy = is_re = True - expected_msg = f'Revealed type is "{expected_msg}"' - - if mypy_14 and util.py39: - # use_lowercase_names, py39 and above - # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501 - - # skip first character which could be capitalized - # "List item x not found" type of message - expected_msg = expected_msg[0] + re.sub( - r"\b(List|Tuple|Dict|Set)\b" - if is_type - else r"\b(List|Tuple|Dict|Set|Type)\b", - lambda m: m.group(1).lower(), - expected_msg[1:], - ) - - if mypy_14 and util.py310: - # use_or_syntax, py310 and above - # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501 - expected_msg = re.sub( - r"Optional\[(.*?)\]", - lambda m: f"{m.group(1)} | None", - expected_msg, - ) - current_assert_messages.append( - (is_mypy, is_re, expected_msg.strip()) - ) - elif current_assert_messages: - expected_messages.extend( - (num, is_mypy, is_re, expected_msg) - for ( - is_mypy, - is_re, - expected_msg, - ) in current_assert_messages - ) - current_assert_messages[:] = [] - - result = mypy_runner(path, use_plugin=use_plugin) - - not_located = [] - - if expected_messages: - # mypy 0.990 changed how return codes work, so don't assume a - # 1 or a 0 return code here, could be either depending on if - # errors were generated or not - - output = [] - - raw_lines = result[0].split("\n") - while raw_lines: - e = raw_lines.pop(0) - if re.match(r".+\.py:\d+: error: .*", e): - output.append(("error", e)) - elif re.match( - r".+\.py:\d+: note: +(?:Possible overload|def ).*", e - ): - while raw_lines: - ol = raw_lines.pop(0) - if not re.match(r".+\.py:\d+: note: +def \[.*", ol): - break - elif re.match( - r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I - ): - pass - elif re.match(r".+\.py:\d+: note: .*", e): - output.append(("note", e)) - - for num, is_mypy, is_re, msg in expected_messages: - msg = msg.replace("'", '"') - prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else "" - for idx, (typ, errmsg) in enumerate(output): - if is_re: - if re.match( - rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}", # noqa: E501 - errmsg, - ): - break - elif ( - f"{filename}:{num}: {typ}: {prefix}{msg}" - in errmsg.replace("'", '"') - ): - break - else: - not_located.append(msg) - continue - del output[idx] - - if not_located: - print(f"Couldn't locate expected messages: {not_located}") - print("\n".join(msg for _, msg in output)) - assert False, "expected messages not found, see stdout" - - if output: - print(f"{len(output)} messages from mypy were not consumed:") - print("\n".join(msg for _, msg in output)) - assert False, "errors and/or notes remain, see stdout" - - else: - if result[2] != 0: - print(result[0]) - - eq_(result[2], 0, msg=result) + def test_plugin_files(self, mypy_typecheck_file, path): + mypy_typecheck_file(path, use_plugin=True) diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index d7b7b0bb20..abf8efec71 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -45,6 +45,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing.assertions import expect_raises_message +from sqlalchemy.testing.entities import ComparableEntity # noqa from sqlalchemy.testing.entities import ComparableMixin # noqa from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -2451,7 +2452,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest): creator=lambda point: PointData(point=point), ) - class PointData(fixtures.ComparableEntity, cls.DeclarativeBasic): + class PointData(ComparableEntity, cls.DeclarativeBasic): __tablename__ = "point" id = Column( diff --git a/test/ext/test_indexable.py b/test/ext/test_indexable.py index e68f9c0351..4421c3a6ed 100644 --- a/test/ext/test_indexable.py +++ b/test/ext/test_indexable.py @@ -15,6 +15,7 @@ from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ from sqlalchemy.testing import ne_ from sqlalchemy.testing import not_in +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.schema import Column @@ -176,7 +177,7 @@ class IndexPropertyArrayTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class Array(fixtures.ComparableEntity, Base): + class Array(ComparableEntity, Base): __tablename__ = "array" id = Column( @@ -270,7 +271,7 @@ class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest): expr = super().expr(model) return expr.astext.cast(self.cast_type) - class Json(fixtures.ComparableEntity, Base): + class Json(ComparableEntity, Base): __tablename__ = "json" id = Column( diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 6c428fa854..dffdac8d84 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -35,6 +35,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock +from sqlalchemy.testing.entities import BasicEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -44,7 +45,7 @@ from sqlalchemy.types import TypeDecorator from sqlalchemy.types import VARCHAR -class Foo(fixtures.BasicEntity): +class Foo(BasicEntity): pass @@ -52,7 +53,7 @@ class SubFoo(Foo): pass -class Foo2(fixtures.BasicEntity): +class Foo2(BasicEntity): pass @@ -68,7 +69,7 @@ class FooWithEq: return self.id == other.id -class FooWNoHash(fixtures.BasicEntity): +class FooWNoHash(BasicEntity): __hash__ = None diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index 8318484c02..a52c59e2d3 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -20,6 +20,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -29,11 +30,11 @@ def pickle_protocols(): # return iter([-1, 0, 1, 2]) -class User(fixtures.ComparableEntity): +class User(ComparableEntity): pass -class Address(fixtures.ComparableEntity): +class Address(ComparableEntity): pass diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 985b600f0d..7085b2af9f 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -59,6 +59,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -517,7 +518,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): xyzzy = "magic" # _as_declarative() inspects obj.__class__.__bases__ - class User(BrokenParent, fixtures.ComparableEntity): + class User(BrokenParent, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -707,7 +708,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): assert Base().foobar() == "foobar" def test_as_declarative(self, metadata): - class User(fixtures.ComparableEntity): + class User(ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -715,7 +716,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): name = Column("name", String(50)) addresses = relationship("Address", backref="user") - class Address(fixtures.ComparableEntity): + class Address(ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -746,7 +747,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): ) def test_map_declaratively(self, metadata): - class User(fixtures.ComparableEntity): + class User(ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -754,7 +755,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): name = Column("name", String(50)) addresses = relationship("Address", backref="user") - class Address(fixtures.ComparableEntity): + class Address(ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -1105,7 +1106,7 @@ class DeclarativeMultiBaseTest( testing.config.skip_test("current base has no metaclass") def test_basic(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( @@ -1114,7 +1115,7 @@ class DeclarativeMultiBaseTest( name = Column("name", String(50)) addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( @@ -1263,7 +1264,7 @@ class DeclarativeMultiBaseTest( ) def test_unicode_string_resolve(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( @@ -1272,7 +1273,7 @@ class DeclarativeMultiBaseTest( name = Column("name", String(50)) addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( @@ -1286,7 +1287,7 @@ class DeclarativeMultiBaseTest( assert User.addresses.property.mapper.class_ is Address def test_unicode_string_resolve_backref(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( @@ -1294,7 +1295,7 @@ class DeclarativeMultiBaseTest( ) name = Column("name", String(50)) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( @@ -1729,7 +1730,7 @@ class DeclarativeMultiBaseTest( assert User.__mapper__.registry._new_mappers is False def test_string_dependency_resolution(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1747,7 +1748,7 @@ class DeclarativeMultiBaseTest( ), ) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1780,7 +1781,7 @@ class DeclarativeMultiBaseTest( ), ) - class Foo(Base, fixtures.ComparableEntity): + class Foo(Base, ComparableEntity): __tablename__ = "foo" id = Column(Integer, primary_key=True) rel = relationship("User", primaryjoin="User.addresses==Foo.id") @@ -1792,7 +1793,7 @@ class DeclarativeMultiBaseTest( ) def test_string_dependency_resolution_synonym(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1807,7 +1808,7 @@ class DeclarativeMultiBaseTest( sess.expunge_all() eq_(sess.query(User).filter(User.name == "ed").one(), User(name="ed")) - class Foo(Base, fixtures.ComparableEntity): + class Foo(Base, ComparableEntity): __tablename__ = "foo" id = Column(Integer, primary_key=True) _user_id = Column(Integer) @@ -1902,14 +1903,14 @@ class DeclarativeMultiBaseTest( ) def test_string_dependency_resolution_no_table(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( Integer, primary_key=True, test_needs_autoincrement=True ) name = Column(String(50)) - class Bar(Base, fixtures.ComparableEntity): + class Bar(Base, ComparableEntity): __tablename__ = "bar" id = Column(Integer, primary_key=True) rel = relationship("User", primaryjoin="User.id==Bar.__table__.id") @@ -1921,14 +1922,14 @@ class DeclarativeMultiBaseTest( ) def test_string_w_pj_annotations(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( Integer, primary_key=True, test_needs_autoincrement=True ) name = Column(String(50)) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1947,7 +1948,7 @@ class DeclarativeMultiBaseTest( def test_string_dependency_resolution_no_magic(self): """test that full tinkery expressions work as written""" - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column(Integer, primary_key=True) addresses = relationship( @@ -1955,7 +1956,7 @@ class DeclarativeMultiBaseTest( primaryjoin="User.id==Address.user_id.prop.columns[0]", ) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("users.id")) @@ -1967,7 +1968,7 @@ class DeclarativeMultiBaseTest( ) def test_string_dependency_resolution_module_qualified(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column(Integer, primary_key=True) addresses = relationship( @@ -1976,7 +1977,7 @@ class DeclarativeMultiBaseTest( % (__name__, __name__), ) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("users.id")) @@ -1988,7 +1989,7 @@ class DeclarativeMultiBaseTest( ) def test_string_dependency_resolution_in_backref(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -1998,7 +1999,7 @@ class DeclarativeMultiBaseTest( backref="user", ) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column(Integer, primary_key=True) email = Column(String(50)) @@ -2011,7 +2012,7 @@ class DeclarativeMultiBaseTest( ) def test_string_dependency_resolution_tables(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -2023,7 +2024,7 @@ class DeclarativeMultiBaseTest( backref="users", ) - class Prop(Base, fixtures.ComparableEntity): + class Prop(Base, ComparableEntity): __tablename__ = "props" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -2042,7 +2043,7 @@ class DeclarativeMultiBaseTest( def test_string_dependency_resolution_table_over_class(self): # test for second half of #5774 - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -2052,7 +2053,7 @@ class DeclarativeMultiBaseTest( backref="users", ) - class Prop(Base, fixtures.ComparableEntity): + class Prop(Base, ComparableEntity): __tablename__ = "props" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -2071,7 +2072,7 @@ class DeclarativeMultiBaseTest( def test_string_dependency_resolution_class_over_table(self): # test for second half of #5774 - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -2091,7 +2092,7 @@ class DeclarativeMultiBaseTest( ) def test_uncompiled_attributes_in_relationship(self): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -2099,7 +2100,7 @@ class DeclarativeMultiBaseTest( email = Column(String(50)) user_id = Column(Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -2276,14 +2277,14 @@ class DeclarativeMultiBaseTest( def test_add_prop_auto( self, require_metaclass, assert_user_address_mapping, _column ): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column("id", Integer, primary_key=True) User.name = _column("name", String(50)) User.addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = _column(Integer, primary_key=True) @@ -2300,7 +2301,7 @@ class DeclarativeMultiBaseTest( @testing.combinations(Column, mapped_column, argnames="_column") def test_add_prop_manual(self, assert_user_address_mapping, _column): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = _column("id", Integer, primary_key=True) @@ -2309,7 +2310,7 @@ class DeclarativeMultiBaseTest( User, "addresses", relationship("Address", backref="user") ) - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = _column(Integer, primary_key=True) @@ -2404,7 +2405,7 @@ class DeclarativeMultiBaseTest( A(brap=B()) def test_eager_order_by(self): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2412,7 +2413,7 @@ class DeclarativeMultiBaseTest( email = Column("email", String(50)) user_id = Column("user_id", Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2439,7 +2440,7 @@ class DeclarativeMultiBaseTest( ) def test_order_by_multi(self): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2447,7 +2448,7 @@ class DeclarativeMultiBaseTest( email = Column("email", String(50)) user_id = Column("user_id", Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2473,7 +2474,7 @@ class DeclarativeMultiBaseTest( "Ignoring declarative-like tuple value of " "attribute 'name'" ): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column("id", Integer, primary_key=True) name = (Column("name", String(50)),) @@ -2573,7 +2574,7 @@ class DeclarativeMultiBaseTest( is_(inspect(Employee).local_table, Person.__table__) def test_expression(self, require_metaclass): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2581,7 +2582,7 @@ class DeclarativeMultiBaseTest( name = Column("name", String(50)) addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2614,7 +2615,7 @@ class DeclarativeMultiBaseTest( ) def test_useless_declared_attr(self): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2622,7 +2623,7 @@ class DeclarativeMultiBaseTest( email = Column("email", String(50)) user_id = Column("user_id", Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2696,7 +2697,7 @@ class DeclarativeMultiBaseTest( return Column(Integer) def test_column(self, require_metaclass): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2736,7 +2737,7 @@ class DeclarativeMultiBaseTest( eq_(Foo.d.impl.active_history, False) def test_column_properties(self): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -2744,7 +2745,7 @@ class DeclarativeMultiBaseTest( email = Column(String(50)) user_id = Column(Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2778,13 +2779,13 @@ class DeclarativeMultiBaseTest( ) def test_column_properties_2(self): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column(Integer, primary_key=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column("id", Integer, primary_key=True) name = Column("name", String(50)) @@ -2798,7 +2799,7 @@ class DeclarativeMultiBaseTest( eq_(set(Address.__table__.c.keys()), {"id", "email", "user_id"}) def test_deferred(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -2819,7 +2820,7 @@ class DeclarativeMultiBaseTest( self.assert_sql_count(testing.db, go, 1) def test_composite_inline(self): - class AddressComposite(fixtures.ComparableEntity): + class AddressComposite(ComparableEntity): def __init__(self, street, state): self.street = street self.state = state @@ -2827,7 +2828,7 @@ class DeclarativeMultiBaseTest( def __composite_values__(self): return [self.street, self.state] - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "user" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -2848,7 +2849,7 @@ class DeclarativeMultiBaseTest( ) def test_composite_separate(self): - class AddressComposite(fixtures.ComparableEntity): + class AddressComposite(ComparableEntity): def __init__(self, street, state): self.street = street self.state = state @@ -2856,7 +2857,7 @@ class DeclarativeMultiBaseTest( def __composite_values__(self): return [self.street, self.state] - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "user" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -2903,7 +2904,7 @@ class DeclarativeMultiBaseTest( ) def test_synonym_inline(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2939,7 +2940,7 @@ class DeclarativeMultiBaseTest( def __eq__(self, other): return self.__clause_element__() == other + " FOO" - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2955,7 +2956,7 @@ class DeclarativeMultiBaseTest( eq_(sess.query(User).filter(User.name == "someuser").one(), u1) def test_synonym_added(self, require_metaclass): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2982,7 +2983,7 @@ class DeclarativeMultiBaseTest( ) def test_reentrant_compile_via_foreignkey(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -2990,7 +2991,7 @@ class DeclarativeMultiBaseTest( name = Column("name", String(50)) addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -3026,7 +3027,7 @@ class DeclarativeMultiBaseTest( ) def test_relationship_reference(self, require_metaclass): - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -3034,7 +3035,7 @@ class DeclarativeMultiBaseTest( email = Column("email", String(50)) user_id = Column("user_id", Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -3105,7 +3106,7 @@ class DeclarativeMultiBaseTest( eq_(sess.execute(t1.select()).fetchall(), [("someid", "somedata")]) def test_synonym_for(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index 79639ed9ce..c5b908cd82 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -28,6 +28,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -147,7 +148,7 @@ class DeclarativeInheritanceTest( configure_mappers() def test_joined(self): - class Company(Base, fixtures.ComparableEntity): + class Company(Base, ComparableEntity): __tablename__ = "companies" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -155,7 +156,7 @@ class DeclarativeInheritanceTest( name = Column("name", String(50)) employees = relationship("Person") - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -245,7 +246,7 @@ class DeclarativeInheritanceTest( self.assert_sql_count(testing.db, go, 1) def test_add_subcol_after_the_fact(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -274,7 +275,7 @@ class DeclarativeInheritanceTest( ) def test_add_parentcol_after_the_fact(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -303,7 +304,7 @@ class DeclarativeInheritanceTest( ) def test_add_sub_parentcol_after_the_fact(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -340,7 +341,7 @@ class DeclarativeInheritanceTest( ) def test_subclass_mixin(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column("id", Integer, primary_key=True) name = Column("name", String(50)) @@ -532,7 +533,7 @@ class DeclarativeInheritanceTest( """test single inheritance where all the columns are on the base class.""" - class Company(Base, fixtures.ComparableEntity): + class Company(Base, ComparableEntity): __tablename__ = "companies" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -540,7 +541,7 @@ class DeclarativeInheritanceTest( name = Column("name", String(50)) employees = relationship("Person") - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -604,7 +605,7 @@ class DeclarativeInheritanceTest( """ - class Company(Base, fixtures.ComparableEntity): + class Company(Base, ComparableEntity): __tablename__ = "companies" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -612,7 +613,7 @@ class DeclarativeInheritanceTest( name = Column("name", String(50)) employees = relationship("Person") - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -787,7 +788,7 @@ class DeclarativeInheritanceTest( def test_single_constraint_on_sub(self): """test the somewhat unusual case of [ticket:3341]""" - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1147,7 +1148,7 @@ class DeclarativeInheritanceTest( is_(Manager.id.property.columns[0], Person.__table__.c.id) def test_joined_from_single(self): - class Company(Base, fixtures.ComparableEntity): + class Company(Base, ComparableEntity): __tablename__ = "companies" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -1155,7 +1156,7 @@ class DeclarativeInheritanceTest( name = Column("name", String(50)) employees = relationship("Person") - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1220,7 +1221,7 @@ class DeclarativeInheritanceTest( ) def test_single_from_joined_colsonsub(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1268,7 +1269,7 @@ class DeclarativeInheritanceTest( is_(B.__mapper__.polymorphic_on, A.__table__.c.discriminator) def test_add_deferred(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -1292,7 +1293,7 @@ class DeclarativeInheritanceTest( """ - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1306,7 +1307,7 @@ class DeclarativeInheritanceTest( primary_language_id = Column(Integer, ForeignKey("languages.id")) primary_language = relationship("Language") - class Language(Base, fixtures.ComparableEntity): + class Language(Base, ComparableEntity): __tablename__ = "languages" id = Column( Integer, primary_key=True, test_needs_autoincrement=True @@ -1354,7 +1355,7 @@ class DeclarativeInheritanceTest( ) def test_single_three_levels(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -1415,7 +1416,7 @@ class DeclarativeInheritanceTest( assert_raises(sa.exc.ArgumentError, go) def test_single_no_special_cols(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column("id", Integer, primary_key=True) name = Column("name", String(50)) @@ -1431,7 +1432,7 @@ class DeclarativeInheritanceTest( assert_raises_message(sa.exc.ArgumentError, "place primary key", go) def test_single_no_table_args(self): - class Person(Base, fixtures.ComparableEntity): + class Person(Base, ComparableEntity): __tablename__ = "people" id = Column("id", Integer, primary_key=True) name = Column("name", String(50)) diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py index a2ed8f0ebf..be9d28b193 100644 --- a/test/orm/declarative/test_reflection.py +++ b/test/orm/declarative/test_reflection.py @@ -9,6 +9,8 @@ from sqlalchemy.orm import relationship from sqlalchemy.testing import assert_raises from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity +from sqlalchemy.testing.entities import ComparableMixin from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -62,12 +64,12 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): ) def test_basic(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" __autoload_with__ = testing.db addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" __autoload_with__ = testing.db @@ -92,13 +94,13 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): eq_(a1.user, User(name="u1")) def test_rekey_wbase(self): - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" __autoload_with__ = testing.db nom = Column("name", String(50), key="nom") addresses = relationship("Address", backref="user") - class Address(Base, fixtures.ComparableEntity): + class Address(Base, ComparableEntity): __tablename__ = "addresses" __autoload_with__ = testing.db @@ -125,14 +127,14 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): def test_rekey_wdecorator(self): @registry.mapped - class User(fixtures.ComparableMixin): + class User(ComparableMixin): __tablename__ = "users" __autoload_with__ = testing.db nom = Column("name", String(50), key="nom") addresses = relationship("Address", backref="user") @registry.mapped - class Address(fixtures.ComparableMixin): + class Address(ComparableMixin): __tablename__ = "addresses" __autoload_with__ = testing.db @@ -158,12 +160,12 @@ class DeclarativeReflectionTest(DeclarativeReflectionBase): assert_raises(TypeError, User, name="u3") def test_supplied_fk(self): - class IMHandle(Base, fixtures.ComparableEntity): + class IMHandle(Base, ComparableEntity): __tablename__ = "imhandles" __autoload_with__ = testing.db user_id = Column("user_id", Integer, ForeignKey("users.id")) - class User(Base, fixtures.ComparableEntity): + class User(Base, ComparableEntity): __tablename__ = "users" __autoload_with__ = testing.db handles = relationship("IMHandle", backref="user") diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 9550671119..2888aeaf9e 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -61,7 +61,7 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): style: testing.Variation, sort_by_parameter_order, ): - class A(fixtures.ComparableEntity, decl_base): + class A(ComparableEntity, decl_base): __tablename__ = "a" id: Mapped[int] = mapped_column(Identity(), primary_key=True) data: Mapped[str] @@ -1700,7 +1700,7 @@ class BulkDMLReturningJoinedInhTest( def setup_classes(cls): decl_base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, decl_base): + class A(ComparableEntity, decl_base): __tablename__ = "a" id: Mapped[int] = mapped_column(Identity(), primary_key=True) type: Mapped[str] @@ -1814,7 +1814,7 @@ class BulkDMLReturningSingleInhTest( def setup_classes(cls): decl_base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, decl_base): + class A(ComparableEntity, decl_base): __tablename__ = "a" id: Mapped[int] = mapped_column(Identity(), primary_key=True) type: Mapped[str] @@ -1857,7 +1857,7 @@ class BulkDMLReturningConcreteInhTest( def setup_classes(cls): decl_base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, decl_base): + class A(ComparableEntity, decl_base): __tablename__ = "a" id: Mapped[int] = mapped_column(Identity(), primary_key=True) type: Mapped[str] @@ -1897,7 +1897,7 @@ class CTETest(fixtures.DeclarativeMappedTest): def setup_classes(cls): decl_base = cls.DeclarativeBasic - class User(fixtures.ComparableEntity, decl_base): + class User(ComparableEntity, decl_base): __tablename__ = "users" id: Mapped[uuid.UUID] = mapped_column(primary_key=True) username: Mapped[str] diff --git a/test/orm/inheritance/_poly_fixtures.py b/test/orm/inheritance/_poly_fixtures.py index ae64bc9f92..5b5989c920 100644 --- a/test/orm/inheritance/_poly_fixtures.py +++ b/test/orm/inheritance/_poly_fixtures.py @@ -9,15 +9,16 @@ from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import config from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -class Company(fixtures.ComparableEntity): +class Company(ComparableEntity): pass -class Person(fixtures.ComparableEntity): +class Person(ComparableEntity): pass @@ -33,19 +34,19 @@ class Boss(Manager): pass -class Machine(fixtures.ComparableEntity): +class Machine(ComparableEntity): pass -class MachineType(fixtures.ComparableEntity): +class MachineType(ComparableEntity): pass -class Paperwork(fixtures.ComparableEntity): +class Paperwork(ComparableEntity): pass -class Page(fixtures.ComparableEntity): +class Page(ComparableEntity): pass @@ -568,7 +569,7 @@ class GeometryFixtureBase(fixtures.DeclarativeMappedTest): items["__mapper_args__"][mapper_opt] = value[mapper_opt] if is_base: - klass = type(key, (fixtures.ComparableEntity, base), items) + klass = type(key, (ComparableEntity, base), items) else: klass = type(key, (base,), items) diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 3ec9b55857..f0967d86cc 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -4,6 +4,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -37,7 +38,7 @@ class ABCTest(fixtures.MappedTest): @testing.combinations(("union",), ("none",)) def test_abc_poly_roundtrip(self, fetchtype): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass class B(A): diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 6032229528..aa076e19f9 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -40,7 +40,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import ComparableEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column @@ -1083,7 +1083,7 @@ class RelationshipTest8(fixtures.MappedTest): ) def test_selfref_onjoined(self): - class Taggable(fixtures.ComparableEntity): + class Taggable(ComparableEntity): pass class User(Taggable): @@ -1880,14 +1880,14 @@ class InheritingEagerTest(fixtures.MappedTest): """test that Query uses the full set of mapper._eager_loaders when generating SQL""" - class Person(fixtures.ComparableEntity): + class Person(ComparableEntity): pass class Employee(Person): def __init__(self, name="bob"): self.name = name - class Tag(fixtures.ComparableEntity): + class Tag(ComparableEntity): def __init__(self, label): self.label = label diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index ab97c5f250..769ef645e8 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -46,6 +46,8 @@ from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.assertsql import Or from sqlalchemy.testing.assertsql import RegexSQL +from sqlalchemy.testing.entities import BasicEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -188,7 +190,7 @@ class PolyExpressionEagerLoad(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column( @@ -782,7 +784,7 @@ class PolymorphicSynonymTest(fixtures.MappedTest): ) def test_polymorphic_synonym(self): - class T1(fixtures.ComparableEntity): + class T1(ComparableEntity): def info(self): return "THE INFO IS:" + self._info @@ -1055,16 +1057,16 @@ class CascadeTest(fixtures.MappedTest): ) def test_cascade(self): - class T1(fixtures.BasicEntity): + class T1(BasicEntity): pass - class T2(fixtures.BasicEntity): + class T2(BasicEntity): pass class T3(T2): pass - class T4(fixtures.BasicEntity): + class T4(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -1137,7 +1139,7 @@ class M2OUseGetTest(fixtures.MappedTest): ) # test [ticket:1186] - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class Sub(Base): @@ -1405,7 +1407,7 @@ class EagerTargetingTest(fixtures.MappedTest): def test_adapt_stringency(self): b_table, a_table = self.tables.b_table, self.tables.a_table - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass class B(A): @@ -2107,7 +2109,7 @@ class VersioningTest(fixtures.MappedTest): self.tables.stuff, ) - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class Sub(Base): @@ -2171,7 +2173,7 @@ class VersioningTest(fixtures.MappedTest): def test_delete(self): subtable, base = self.tables.subtable, self.tables.base - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class Sub(Base): @@ -2833,10 +2835,10 @@ class OptimizedLoadTest(fixtures.MappedTest): def test_no_optimize_on_map_to_join(self): base, sub = self.tables.base, self.tables.sub - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass - class JoinBase(fixtures.ComparableEntity): + class JoinBase(ComparableEntity): pass class SubJoinBase(JoinBase): @@ -2902,7 +2904,7 @@ class OptimizedLoadTest(fixtures.MappedTest): base, sub = self.tables.base, self.tables.sub - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass class Sub(Base): @@ -3014,7 +3016,7 @@ class OptimizedLoadTest(fixtures.MappedTest): base, sub = self.tables.base, self.tables.sub - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass class Sub(Base): @@ -3050,7 +3052,7 @@ class OptimizedLoadTest(fixtures.MappedTest): def test_column_expression(self): base, sub = self.tables.base, self.tables.sub - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass class Sub(Base): @@ -3079,7 +3081,7 @@ class OptimizedLoadTest(fixtures.MappedTest): def test_column_expression_joined(self): base, sub = self.tables.base, self.tables.sub - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass class Sub(Base): @@ -3120,7 +3122,7 @@ class OptimizedLoadTest(fixtures.MappedTest): def test_composite_column_joined(self): base, with_comp = self.tables.base, self.tables.with_comp - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class WithComp(Base): @@ -3168,7 +3170,7 @@ class OptimizedLoadTest(fixtures.MappedTest): expected_eager_defaults and testing.db.dialect.insert_returning ) - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class Sub(Base): @@ -3257,7 +3259,7 @@ class OptimizedLoadTest(fixtures.MappedTest): def test_dont_generate_on_none(self): base, sub = self.tables.base, self.tables.sub - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class Sub(Base): @@ -3305,7 +3307,7 @@ class OptimizedLoadTest(fixtures.MappedTest): self.tables.subsub, ) - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class Sub(Base): @@ -3820,13 +3822,13 @@ class DeleteOrphanTest(fixtures.MappedTest): ) def test_orphan_message(self): - class Base(fixtures.BasicEntity): + class Base(BasicEntity): pass class SubClass(Base): pass - class Parent(fixtures.BasicEntity): + class Parent(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -3927,11 +3929,11 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class Parent(fixtures.ComparableEntity, Base): + class Parent(ComparableEntity, Base): __tablename__ = "parent" id = Column(Integer, primary_key=True) - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column(Integer, primary_key=True) parent_id = Column(ForeignKey("parent.id")) @@ -4019,7 +4021,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class AJoined(fixtures.ComparableEntity, Base): + class AJoined(ComparableEntity, Base): __tablename__ = "ajoined" id = Column(Integer, primary_key=True) type = Column(String(10), nullable=False) @@ -4038,7 +4040,7 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): id = Column(ForeignKey("ajoined.id"), primary_key=True) __mapper_args__ = {"polymorphic_identity": "subb"} - class ASingle(fixtures.ComparableEntity, Base): + class ASingle(ComparableEntity, Base): __tablename__ = "asingle" id = Column(Integer, primary_key=True) type = Column(String(10), nullable=False) @@ -4110,7 +4112,7 @@ class CompositeJoinedInTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "table_a" order_id: Mapped[str] = mapped_column(String(50), primary_key=True) diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py index 755af492ca..8790848288 100644 --- a/test/orm/inheritance/test_poly_loading.py +++ b/test/orm/inheritance/test_poly_loading.py @@ -829,11 +829,11 @@ class LoaderOptionsTest( def setup_classes(cls): Base = cls.DeclarativeBasic - class Parent(fixtures.ComparableEntity, Base): + class Parent(ComparableEntity, Base): __tablename__ = "parent" id = Column(Integer, primary_key=True) - class Child(fixtures.ComparableEntity, Base): + class Child(ComparableEntity, Base): __tablename__ = "child" id = Column(Integer, primary_key=True) parent_id = Column(Integer, ForeignKey("parent.id")) @@ -850,7 +850,7 @@ class LoaderOptionsTest( "polymorphic_load": "selectin", } - class Other(fixtures.ComparableEntity, Base): + class Other(ComparableEntity, Base): __tablename__ = "other" id = Column(Integer, primary_key=True) diff --git a/test/orm/inheritance/test_poly_persistence.py b/test/orm/inheritance/test_poly_persistence.py index f244c91106..0a92b7f5a4 100644 --- a/test/orm/inheritance/test_poly_persistence.py +++ b/test/orm/inheritance/test_poly_persistence.py @@ -13,11 +13,12 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column -class Person(fixtures.ComparableEntity): +class Person(ComparableEntity): pass @@ -33,7 +34,7 @@ class Boss(Manager): pass -class Company(fixtures.ComparableEntity): +class Company(ComparableEntity): pass diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index 293c7dfb59..4ed2a453d3 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -29,11 +29,11 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -class Company(fixtures.ComparableEntity): +class Company(ComparableEntity): pass -class Person(fixtures.ComparableEntity): +class Person(ComparableEntity): pass @@ -49,11 +49,11 @@ class Boss(Manager): pass -class Machine(fixtures.ComparableEntity): +class Machine(ComparableEntity): pass -class Paperwork(fixtures.ComparableEntity): +class Paperwork(ComparableEntity): pass diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py index 47827e8887..5fb15c9b7c 100644 --- a/test/orm/inheritance/test_selects.py +++ b/test/orm/inheritance/test_selects.py @@ -5,6 +5,7 @@ from sqlalchemy import String from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -31,7 +32,7 @@ class InheritingSelectablesTest(fixtures.MappedTest): connection.execute(foo.insert(), dict(a="i am bar", b="bar")) connection.execute(foo.insert(), dict(a="also bar", b="bar")) - class Foo(fixtures.ComparableEntity): + class Foo(ComparableEntity): pass class Bar(Foo): diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 4461ac86d2..52f3cf9c9f 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -34,6 +34,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -1693,7 +1694,7 @@ class SingleOnJoinedTest(fixtures.MappedTest): ) def test_single_on_joined(self): - class Person(fixtures.ComparableEntity): + class Person(ComparableEntity): pass class Employee(Person): diff --git a/test/orm/test_ac_relationships.py b/test/orm/test_ac_relationships.py index f53aedf07d..603e71d249 100644 --- a/test/orm/test_ac_relationships.py +++ b/test/orm/test_ac_relationships.py @@ -19,7 +19,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.fixtures import ComparableEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 58e1ab97b9..4b9d3b2e02 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -22,6 +22,7 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import not_in from sqlalchemy.testing.assertions import assert_warns +from sqlalchemy.testing.entities import BasicEntity from sqlalchemy.testing.util import all_partial_orderings from sqlalchemy.testing.util import gc_collect @@ -770,10 +771,10 @@ class AttributesTest(fixtures.ORMTest): def test_lazyhistory(self): """tests that history functions work with lazy-loading attributes""" - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass instrumentation.register_class(Foo) @@ -1737,7 +1738,7 @@ class PendingBackrefTest(fixtures.ORMTest): class HistoryTest(fixtures.TestBase): def _fixture(self, uselist, useobject, active_history, **kw): - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass instrumentation.register_class(Foo) @@ -1752,10 +1753,10 @@ class HistoryTest(fixtures.TestBase): return Foo def _two_obj_fixture(self, uselist, active_history=False): - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): def __bool__(self): assert False @@ -2571,10 +2572,10 @@ class HistoryTest(fixtures.TestBase): def test_dict_collections(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass instrumentation.register_class(Foo) @@ -2630,10 +2631,10 @@ class HistoryTest(fixtures.TestBase): def test_object_collections_mutate(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass instrumentation.register_class(Foo) @@ -2818,10 +2819,10 @@ class HistoryTest(fixtures.TestBase): def test_collections_via_backref(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass instrumentation.register_class(Foo) @@ -2890,10 +2891,10 @@ class LazyloadHistoryTest(fixtures.TestBase): def test_lazy_backref_collections(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass lazy_load = [] @@ -2949,10 +2950,10 @@ class LazyloadHistoryTest(fixtures.TestBase): def test_collections_via_lazyload(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass lazy_load = [] @@ -3012,7 +3013,7 @@ class LazyloadHistoryTest(fixtures.TestBase): def test_scalar_via_lazyload(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass lazy_load = None @@ -3068,7 +3069,7 @@ class LazyloadHistoryTest(fixtures.TestBase): def test_scalar_via_lazyload_with_active(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass lazy_load = None @@ -3129,10 +3130,10 @@ class LazyloadHistoryTest(fixtures.TestBase): def test_scalar_object_via_lazyload(self): # TODO: break into individual tests - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass lazy_load = None @@ -3195,10 +3196,10 @@ class LazyloadHistoryTest(fixtures.TestBase): class CollectionKeyTest(fixtures.ORMTest): @testing.fixture def dict_collection(self): - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): def __init__(self, name): self.name = name @@ -3222,10 +3223,10 @@ class CollectionKeyTest(fixtures.ORMTest): @testing.fixture def list_collection(self): - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass - class Bar(fixtures.BasicEntity): + class Bar(BasicEntity): pass instrumentation.register_class(Foo) diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index 3c49fc8dcd..6b84ec6f78 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -30,6 +30,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -3256,13 +3257,13 @@ class DoubleParentO2MOrphanTest(fixtures.MappedTest): self.tables.accounts, ) - class Customer(fixtures.ComparableEntity): + class Customer(ComparableEntity): pass - class Account(fixtures.ComparableEntity): + class Account(ComparableEntity): pass - class SalesRep(fixtures.ComparableEntity): + class SalesRep(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3428,13 +3429,13 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): self.tables.addresses, ) - class Address(fixtures.ComparableEntity): + class Address(ComparableEntity): pass - class Home(fixtures.ComparableEntity): + class Home(ComparableEntity): pass - class Business(fixtures.ComparableEntity): + class Business(ComparableEntity): pass self.mapper_registry.map_imperatively(Address, addresses) @@ -3488,13 +3489,13 @@ class DoubleParentM2OOrphanTest(fixtures.MappedTest): self.tables.addresses, ) - class Address(fixtures.ComparableEntity): + class Address(ComparableEntity): pass - class Home(fixtures.ComparableEntity): + class Home(ComparableEntity): pass - class Business(fixtures.ComparableEntity): + class Business(ComparableEntity): pass self.mapper_registry.map_imperatively(Address, addresses) @@ -3546,10 +3547,10 @@ class CollectionAssignmentOrphanTest(fixtures.MappedTest): def test_basic(self): table_b, table_a = self.tables.table_b, self.tables.table_a - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -4044,10 +4045,10 @@ class PartialFlushTest(fixtures.MappedTest): def test_o2m_m2o(self): base, noninh_child = self.tables.base, self.tables.noninh_child - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass - class Child(fixtures.ComparableEntity): + class Child(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -4103,7 +4104,7 @@ class PartialFlushTest(fixtures.MappedTest): self.tables.parent, ) - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass class Parent(Base): diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index 562d9b9dc9..d230d4aafc 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -11,6 +11,7 @@ from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.assertsql import RegexSQL +from sqlalchemy.testing.entities import BasicEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -209,7 +210,7 @@ class ExcludedDefaultsTest(fixtures.MappedTest): def test_exclude(self): dt = self.tables.dt - class Foo(fixtures.BasicEntity): + class Foo(BasicEntity): pass self.mapper_registry.map_imperatively( diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index c93ac6d60a..fa044d033c 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -43,6 +43,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -2117,7 +2118,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column(Integer, primary_key=True) x = Column(Integer) @@ -2127,7 +2128,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): bs = relationship("B", order_by="B.id") - class A_default(fixtures.ComparableEntity, Base): + class A_default(ComparableEntity, Base): __tablename__ = "a_default" id = Column(Integer, primary_key=True) x = Column(Integer) @@ -2135,7 +2136,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): my_expr = query_expression(default_expr=literal(15)) - class B(fixtures.ComparableEntity, Base): + class B(ComparableEntity, Base): __tablename__ = "b" id = Column(Integer, primary_key=True) a_id = Column(ForeignKey("a.id")) @@ -2144,7 +2145,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest): b_expr = query_expression() - class C(fixtures.ComparableEntity, Base): + class C(ComparableEntity, Base): __tablename__ = "c" id = Column(Integer, primary_key=True) x = Column(Integer) @@ -2489,7 +2490,7 @@ class RaiseLoadTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column(Integer, primary_key=True) x = Column(Integer) diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 0fa6c94a30..23248349cd 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -58,6 +58,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import CacheKeyFixture from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally @@ -841,7 +842,7 @@ class DeprecatedMapperTest( assert_col = [] - class User(fixtures.ComparableEntity): + class User(ComparableEntity): def _get_name(self): assert_col.append(("get", self._name)) return self._name diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index fa44dbf10d..261269dec1 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -41,6 +41,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -4318,10 +4319,10 @@ class OrderBySecondaryTest(fixtures.MappedTest): def test_ordering(self): a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b) - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -4361,7 +4362,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): def test_basic(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -4437,7 +4438,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): def test_lazy_fallback_doesnt_affect_eager(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -4484,7 +4485,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): def test_with_deferred(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -4545,7 +4546,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): def test_options(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -4620,7 +4621,7 @@ class SelfReferentialEagerTest(fixtures.MappedTest): def test_no_depth(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -4813,7 +4814,7 @@ class SelfReferentialM2MEagerTest(fixtures.MappedTest): def test_basic(self): widget, widget_rel = self.tables.widget, self.tables.widget_rel - class Widget(fixtures.ComparableEntity): + class Widget(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -5236,12 +5237,12 @@ class SubqueryTest(fixtures.MappedTest): self.tables.users_table, ) - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @property def prop_score(self): return sum([tag.prop_score for tag in self.tags]) - class Tag(fixtures.ComparableEntity): + class Tag(ComparableEntity): @property def prop_score(self): return self.score1 * self.score2 @@ -5395,10 +5396,10 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): def _do_test(self, labeled, ondate, aliasstuff): stuff, users = self.tables.stuff, self.tables.users - class User(fixtures.ComparableEntity): + class User(ComparableEntity): pass - class Stuff(fixtures.ComparableEntity): + class Stuff(ComparableEntity): pass self.mapper_registry.map_imperatively(Stuff, stuff) diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 6b28a637ad..51c86a5f1d 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -3912,13 +3912,13 @@ class TestOverlyEagerEquivalentCols(fixtures.MappedTest): self.tables.sub1, ) - class Base(fixtures.ComparableEntity): + class Base(ComparableEntity): pass - class Sub1(fixtures.ComparableEntity): + class Sub1(ComparableEntity): pass - class Sub2(fixtures.ComparableEntity): + class Sub2(ComparableEntity): pass self.mapper_registry.map_imperatively( diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index e3936159ad..4ab9617123 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -32,6 +32,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -1025,10 +1026,10 @@ class GetterStateTest(_fixtures.FixtureTest): Column("data", MyHashType()), ) - class Category(fixtures.ComparableEntity): + class Category(ComparableEntity): pass - class Article(fixtures.ComparableEntity): + class Article(ComparableEntity): pass self.mapper_registry.map_imperatively(Category, category) @@ -1314,10 +1315,10 @@ class CorrelatedTest(fixtures.MappedTest): def test_correlated_lazyload(self): stuff, user_t = self.tables.stuff, self.tables.user_t - class User(fixtures.ComparableEntity): + class User(ComparableEntity): pass - class Stuff(fixtures.ComparableEntity): + class Stuff(ComparableEntity): pass self.mapper_registry.map_imperatively(Stuff, stuff) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 19caf04487..a3aad69f08 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -51,8 +51,8 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ -from sqlalchemy.testing.fixtures import ComparableEntity -from sqlalchemy.testing.fixtures import ComparableMixin +from sqlalchemy.testing.entities import ComparableEntity +from sqlalchemy.testing.entities import ComparableMixin from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -973,7 +973,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert_col = [] - class User(fixtures.ComparableEntity): + class User(ComparableEntity): def _get_name(self): assert_col.append(("get", self._name)) return self._name diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 6b3a7c1d6d..0c8e2651cd 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -31,6 +31,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in from sqlalchemy.testing.assertsql import CountStatements +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -2181,10 +2182,10 @@ class LoadOnPendingTest(fixtures.MappedTest): @classmethod def setup_classes(cls): - class Rock(cls.Basic, fixtures.ComparableEntity): + class Rock(cls.Basic, ComparableEntity): pass - class Bug(cls.Basic, fixtures.ComparableEntity): + class Bug(cls.Basic, ComparableEntity): pass def _setup_delete_orphan_o2o(self): @@ -2251,7 +2252,7 @@ class PolymorphicOnTest(fixtures.MappedTest): @classmethod def setup_classes(cls): - class Employee(cls.Basic, fixtures.ComparableEntity): + class Employee(cls.Basic, ComparableEntity): pass class Manager(Employee): diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 367f854427..ce5c64a43a 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -79,6 +79,7 @@ from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.assertions import is_not_none from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -1321,7 +1322,7 @@ class GetTest(QueryTest): s = users.outerjoin(addresses) - class UserThing(fixtures.ComparableEntity): + class UserThing(ComparableEntity): pass registry.map_imperatively( diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 12651fe364..2de35a9a1e 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -42,6 +42,8 @@ from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.entities import BasicEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -334,10 +336,10 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): def _fixture(self, uselist=False): a, b = self.tables.a, self.tables.b - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -1720,7 +1722,7 @@ class FKsAsPksTest(fixtures.MappedTest): ) tableC.create(connection) - class C(fixtures.BasicEntity): + class C(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -1958,10 +1960,10 @@ class RelationshipToSelectableTest(fixtures.MappedTest): def test_basic(self): items = self.tables.items - class Container(fixtures.BasicEntity): + class Container(BasicEntity): pass - class LineItem(fixtures.BasicEntity): + class LineItem(BasicEntity): pass container_select = ( @@ -2050,10 +2052,10 @@ class FKEquatedToConstantTest(fixtures.MappedTest): def test_basic(self): tag_foo, tags = self.tables.tag_foo, self.tables.tags - class Tag(fixtures.ComparableEntity): + class Tag(ComparableEntity): pass - class TagInstance(fixtures.ComparableEntity): + class TagInstance(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -2682,13 +2684,13 @@ class TypeMatchTest(fixtures.MappedTest): def test_o2m_oncascade(self): a, c, b = (self.tables.a, self.tables.c, self.tables.b) - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass - class C(fixtures.BasicEntity): + class C(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -2716,13 +2718,13 @@ class TypeMatchTest(fixtures.MappedTest): def test_o2m_onflush(self): a, c, b = (self.tables.a, self.tables.c, self.tables.b) - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass - class C(fixtures.BasicEntity): + class C(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -2747,10 +2749,10 @@ class TypeMatchTest(fixtures.MappedTest): def test_o2m_nopoly_onflush(self): a, c, b = (self.tables.a, self.tables.c, self.tables.b) - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass class C(B): @@ -2778,13 +2780,13 @@ class TypeMatchTest(fixtures.MappedTest): def test_m2o_nopoly_onflush(self): a, b, d = (self.tables.a, self.tables.b, self.tables.d) - class A(fixtures.BasicEntity): + class A(BasicEntity): pass class B(A): pass - class D(fixtures.BasicEntity): + class D(BasicEntity): pass self.mapper_registry.map_imperatively(A, a) @@ -2805,13 +2807,13 @@ class TypeMatchTest(fixtures.MappedTest): def test_m2o_oncascade(self): a, b, d = (self.tables.a, self.tables.b, self.tables.d) - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass - class D(fixtures.BasicEntity): + class D(BasicEntity): pass self.mapper_registry.map_imperatively(A, a) @@ -2865,10 +2867,10 @@ class TypedAssociationTable(fixtures.MappedTest): t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1) - class T1(fixtures.BasicEntity): + class T1(BasicEntity): pass - class T2(fixtures.BasicEntity): + class T2(BasicEntity): pass self.mapper_registry.map_imperatively(T2, t2) @@ -2928,10 +2930,10 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): ) def test_join_on_custom_op_legacy_is_comparison(self): - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -2955,10 +2957,10 @@ class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): ) def test_join_on_custom_bool_op(self): - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -3016,10 +3018,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): return s def test_o2m_viewonly_oneside(self): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3049,10 +3051,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): assert b1 not in sess.dirty def test_m2o_viewonly_oneside(self): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3082,10 +3084,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): assert b1 not in sess.dirty def test_o2m_viewonly_only(self): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3103,10 +3105,10 @@ class ViewOnlyHistoryTest(fixtures.MappedTest): self._assert_fk(a1, b1, False) def test_m2o_viewonly_only(self): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively(A, self.tables.t1) @@ -3151,10 +3153,10 @@ class ViewOnlyM2MBackrefTest(fixtures.MappedTest): def test_viewonly(self): t1t2, t2, t1 = (self.tables.t1t2, self.tables.t2, self.tables.t1) - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3226,13 +3228,13 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest): t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1) - class C1(fixtures.BasicEntity): + class C1(BasicEntity): pass - class C2(fixtures.BasicEntity): + class C2(BasicEntity): pass - class C3(fixtures.BasicEntity): + class C3(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -3360,10 +3362,10 @@ class ViewOnlySyncBackref(fixtures.MappedTest): @testing.combinations(True, False, None, argnames="B_a_sync") @testing.combinations(True, False, argnames="B_a_view") def test_case(self, B_a_view, B_a_sync, A_bs_view, A_bs_sync): - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass case = self.cases[(B_a_view, B_a_sync, A_bs_view, A_bs_sync)] @@ -3490,13 +3492,13 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1) - class C1(fixtures.BasicEntity): + class C1(BasicEntity): pass - class C2(fixtures.BasicEntity): + class C2(BasicEntity): pass - class C3(fixtures.BasicEntity): + class C3(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -3596,10 +3598,10 @@ class ViewOnlyNonEquijoin(fixtures.MappedTest): def test_viewonly_join(self): bars, foos = self.tables.bars, self.tables.foos - class Foo(fixtures.ComparableEntity): + class Foo(ComparableEntity): pass - class Bar(fixtures.ComparableEntity): + class Bar(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3669,10 +3671,10 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): def test_relationship_on_or(self): bars, foos = self.tables.bars, self.tables.foos - class Foo(fixtures.ComparableEntity): + class Foo(ComparableEntity): pass - class Bar(fixtures.ComparableEntity): + class Bar(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -3744,10 +3746,10 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): def test_relationship_on_or(self): bars, foos = self.tables.bars, self.tables.foos - class Foo(fixtures.ComparableEntity): + class Foo(ComparableEntity): pass - class Bar(fixtures.ComparableEntity): + class Bar(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -4006,7 +4008,7 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class Network(fixtures.ComparableEntity, Base): + class Network(ComparableEntity, Base): __tablename__ = "network" id = Column( @@ -4023,7 +4025,7 @@ class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): viewonly=True, ) - class Address(fixtures.ComparableEntity, Base): + class Address(ComparableEntity, Base): __tablename__ = "address" ip_addr = Column(Integer, primary_key=True) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 8c6ddfa0e5..509137afe3 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -17,6 +17,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -49,10 +50,10 @@ class ScopedSessionTest(fixtures.MappedTest): class CustomQuery(query.Query): pass - class SomeObject(fixtures.ComparableEntity): + class SomeObject(ComparableEntity): query = Session.query_property() - class SomeOtherObject(fixtures.ComparableEntity): + class SomeOtherObject(ComparableEntity): query = Session.query_property() custom_query = Session.query_property(query_cls=CustomQuery) diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index 2fdc12574f..c9907c7651 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -29,7 +29,7 @@ from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.fixtures import ComparableEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -1712,10 +1712,10 @@ class OrderBySecondaryTest(fixtures.MappedTest): def test_ordering(self): a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b) - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -2239,14 +2239,14 @@ class TupleTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id1 = Column(Integer, primary_key=True) id2 = Column(Integer, primary_key=True) bs = relationship("B", order_by="B.id", back_populates="a") - class B(fixtures.ComparableEntity, Base): + class B(ComparableEntity, Base): __tablename__ = "b" id = Column(Integer, primary_key=True) a_id1 = Column() @@ -2355,12 +2355,12 @@ class ChunkingTest(fixtures.DeclarativeMappedTest): def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column(Integer, primary_key=True) bs = relationship("B", order_by="B.id", back_populates="a") - class B(fixtures.ComparableEntity, Base): + class B(ComparableEntity, Base): __tablename__ = "b" id = Column(Integer, primary_key=True) a_id = Column(ForeignKey("a.id")) @@ -2955,7 +2955,7 @@ class SelfRefInheritanceAliasedTest( def setup_classes(cls): Base = cls.DeclarativeBasic - class Foo(fixtures.ComparableEntity, Base): + class Foo(ComparableEntity, Base): __tablename__ = "foo" id = Column(Integer, primary_key=True) type = Column(String(50)) @@ -3203,14 +3203,14 @@ class MissingForeignTest( def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column(Integer, primary_key=True) b_id = Column(Integer) b = relationship("B", primaryjoin="foreign(A.b_id) == B.id") q = Column(Integer) - class B(fixtures.ComparableEntity, Base): + class B(ComparableEntity, Base): __tablename__ = "b" id = Column(Integer, primary_key=True) x = Column(Integer) @@ -3256,7 +3256,7 @@ class M2OWDegradeTest( def setup_classes(cls): Base = cls.DeclarativeBasic - class A(fixtures.ComparableEntity, Base): + class A(ComparableEntity, Base): __tablename__ = "a" id = Column(Integer, primary_key=True) b_id = Column(ForeignKey("b.id")) @@ -3264,7 +3264,7 @@ class M2OWDegradeTest( b_no_omit_join = relationship("B", omit_join=False, overlaps="b") q = Column(Integer) - class B(fixtures.ComparableEntity, Base): + class B(ComparableEntity, Base): __tablename__ = "b" id = Column(Integer, primary_key=True) x = Column(Integer) diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index 1a83a58be8..00564cfb65 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -1766,10 +1766,10 @@ class OrderBySecondaryTest(fixtures.MappedTest): def test_ordering(self): a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b) - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -2440,7 +2440,7 @@ class SelfReferentialTest(fixtures.MappedTest): def test_basic(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -2516,7 +2516,7 @@ class SelfReferentialTest(fixtures.MappedTest): def test_lazy_fallback_doesnt_affect_eager(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -2562,7 +2562,7 @@ class SelfReferentialTest(fixtures.MappedTest): def test_with_deferred(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -2623,7 +2623,7 @@ class SelfReferentialTest(fixtures.MappedTest): def test_options(self): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) @@ -2680,7 +2680,7 @@ class SelfReferentialTest(fixtures.MappedTest): nodes = self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): def append(self, node): self.children.append(node) diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 78b56f1d46..0937c354f9 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -31,6 +31,8 @@ from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional +from sqlalchemy.testing.entities import BasicEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column @@ -209,10 +211,10 @@ class UnicodeSchemaTest(fixtures.MappedTest): def test_mapping(self): t2, t1 = self.tables.t2, self.tables.t1 - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass - class B(fixtures.ComparableEntity): + class B(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -251,7 +253,7 @@ class UnicodeSchemaTest(fixtures.MappedTest): def test_inheritance_mapping(self): t2, t1 = self.tables.t2, self.tables.t1 - class A(fixtures.ComparableEntity): + class A(ComparableEntity): pass class B(A): @@ -1030,7 +1032,7 @@ class ColumnCollisionTest(fixtures.MappedTest): def test_naming(self): book = self.tables.book - class Book(fixtures.ComparableEntity): + class Book(ComparableEntity): pass self.mapper_registry.map_imperatively(Book, book) @@ -1909,7 +1911,7 @@ class SaveTest(_fixtures.FixtureTest): def test_synonym(self): users = self.tables.users - class SUser(fixtures.BasicEntity): + class SUser(BasicEntity): def _get_name(self): return "User:" + self.name @@ -2773,7 +2775,7 @@ class ManyToManyTest(_fixtures.FixtureTest): self.classes.Item, ) - class IKAssociation(fixtures.ComparableEntity): + class IKAssociation(ComparableEntity): pass self.mapper_registry.map_imperatively(Keyword, keywords) @@ -3026,7 +3028,7 @@ class BooleanColTest(fixtures.MappedTest): t1_t = self.tables.t1_t # use the regular mapper - class T(fixtures.ComparableEntity): + class T(ComparableEntity): pass self.mapper_registry.map_imperatively(T, t1_t) diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 5cf8bd573f..5ca34d9174 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -44,6 +44,8 @@ from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.assertsql import RegexSQL +from sqlalchemy.testing.entities import BasicEntity +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.provision import normalize_sequence from sqlalchemy.testing.schema import Column @@ -1372,10 +1374,10 @@ class SingleCyclePlusAttributeTest( def test_flush_size(self): foobars, nodes = self.tables.foobars, self.tables.nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): pass - class FooBar(fixtures.ComparableEntity): + class FooBar(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -1440,7 +1442,7 @@ class SingleCycleM2MTest( def test_many_to_many_one(self): nodes, node_to_nodes = self.tables.nodes, self.tables.node_to_nodes - class Node(fixtures.ComparableEntity): + class Node(ComparableEntity): pass self.mapper_registry.map_imperatively( @@ -1584,10 +1586,10 @@ class RowswitchAccountingTest(fixtures.MappedTest): def _fixture(self): parent, child = self.tables.parent, self.tables.child - class Parent(fixtures.BasicEntity): + class Parent(BasicEntity): pass - class Child(fixtures.BasicEntity): + class Child(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -1678,13 +1680,13 @@ class RowswitchM2OTest(fixtures.MappedTest): def _fixture(self): a, b, c = self.tables.a, self.tables.b, self.tables.c - class A(fixtures.BasicEntity): + class A(BasicEntity): pass - class B(fixtures.BasicEntity): + class B(BasicEntity): pass - class C(fixtures.BasicEntity): + class C(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -1787,10 +1789,10 @@ class BasicStaleChecksTest(fixtures.MappedTest): def _fixture(self, confirm_deleted_rows=True): parent, child = self.tables.parent, self.tables.child - class Parent(fixtures.BasicEntity): + class Parent(BasicEntity): pass - class Child(fixtures.BasicEntity): + class Child(BasicEntity): pass self.mapper_registry.map_imperatively( @@ -2081,7 +2083,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): t = self.tables.t - class T(fixtures.ComparableEntity): + class T(ComparableEntity): pass mp = self.mapper_registry.map_imperatively(T, t) diff --git a/test/orm/test_validators.py b/test/orm/test_validators.py index 990d6a4c4b..df7334d5cb 100644 --- a/test/orm/test_validators.py +++ b/test/orm/test_validators.py @@ -9,8 +9,8 @@ from sqlalchemy.orm import validates from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ -from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ +from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from test.orm import _fixtures @@ -20,7 +20,7 @@ class ValidatorTest(_fixtures.FixtureTest): users = self.tables.users canary = Mock() - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @validates("name") def validate_name(self, key, name): canary(key, name) @@ -52,7 +52,7 @@ class ValidatorTest(_fixtures.FixtureTest): canary = Mock() - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @validates("addresses") def validate_address(self, key, ad): canary(key, ad) @@ -87,7 +87,7 @@ class ValidatorTest(_fixtures.FixtureTest): self.classes.Address, ) - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @validates("name") def validate_name(self, key, name): ne_(name, "fred") @@ -119,7 +119,7 @@ class ValidatorTest(_fixtures.FixtureTest): ) canary = Mock() - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @validates("name", include_removes=True) def validate_name(self, key, item, remove): canary(key, item, remove) @@ -175,7 +175,7 @@ class ValidatorTest(_fixtures.FixtureTest): self.classes.Address, ) - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @validates("addresses", include_removes=True) def validate_address(self, key, item, remove): if not remove: @@ -210,7 +210,7 @@ class ValidatorTest(_fixtures.FixtureTest): self.classes.Address, ) - class User(fixtures.ComparableEntity): + class User(ComparableEntity): @validates("addresses", include_removes=True) def validate_address(self, key, item, remove): if not remove: @@ -264,7 +264,7 @@ class ValidatorTest(_fixtures.FixtureTest): ne_(name, "fred") return name + " modified" - class User(fixtures.ComparableEntity): + class User(ComparableEntity): sv = validates("name")(SomeValidator()) self.mapper_registry.map_imperatively(User, users) @@ -332,7 +332,7 @@ class ValidatorTest(_fixtures.FixtureTest): bool(include_removes) and not include_removes.default ) - class User(fixtures.ComparableEntity): + class User(ComparableEntity): if need_remove_param: @validates("addresses", **validate_kw) @@ -347,7 +347,7 @@ class ValidatorTest(_fixtures.FixtureTest): canary(key, item) return item - class Address(fixtures.ComparableEntity): + class Address(ComparableEntity): if need_remove_param: @validates("user", **validate_kw) diff --git a/test/ext/mypy/plain_files/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py similarity index 100% rename from test/ext/mypy/plain_files/pg_stuff.py rename to test/typing/plain_files/dialects/postgresql/pg_stuff.py diff --git a/test/ext/mypy/plain_files/engine_inspection.py b/test/typing/plain_files/engine/engine_inspection.py similarity index 100% rename from test/ext/mypy/plain_files/engine_inspection.py rename to test/typing/plain_files/engine/engine_inspection.py diff --git a/test/typing/plain_files/engine/engines.py b/test/typing/plain_files/engine/engines.py new file mode 100644 index 0000000000..5777b91484 --- /dev/null +++ b/test/typing/plain_files/engine/engines.py @@ -0,0 +1,34 @@ +from sqlalchemy import create_engine +from sqlalchemy import Pool +from sqlalchemy import text + + +def regular() -> None: + e = create_engine("sqlite://") + + # EXPECTED_TYPE: Engine + reveal_type(e) + + with e.connect() as conn: + # EXPECTED_TYPE: Connection + reveal_type(conn) + + result = conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + with e.begin() as conn: + # EXPECTED_TYPE: Connection + reveal_type(conn) + + result = conn.execute(text("select * from table")) + + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + engine = create_engine("postgresql://scott:tiger@localhost/test") + status: str = engine.pool.status() + other_pool: Pool = engine.pool.recreate() + + print(status, other_pool) diff --git a/test/ext/mypy/plain_files/association_proxy_one.py b/test/typing/plain_files/ext/association_proxy/association_proxy_one.py similarity index 100% rename from test/ext/mypy/plain_files/association_proxy_one.py rename to test/typing/plain_files/ext/association_proxy/association_proxy_one.py diff --git a/test/ext/mypy/plain_files/association_proxy_three.py b/test/typing/plain_files/ext/association_proxy/association_proxy_three.py similarity index 100% rename from test/ext/mypy/plain_files/association_proxy_three.py rename to test/typing/plain_files/ext/association_proxy/association_proxy_three.py diff --git a/test/ext/mypy/plain_files/association_proxy_two.py b/test/typing/plain_files/ext/association_proxy/association_proxy_two.py similarity index 100% rename from test/ext/mypy/plain_files/association_proxy_two.py rename to test/typing/plain_files/ext/association_proxy/association_proxy_two.py diff --git a/test/ext/mypy/plain_files/async_sessionmaker.py b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py similarity index 94% rename from test/ext/mypy/plain_files/async_sessionmaker.py rename to test/typing/plain_files/ext/asyncio/async_sessionmaker.py index c253774e2e..664ff0411d 100644 --- a/test/ext/mypy/plain_files/async_sessionmaker.py +++ b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py @@ -88,5 +88,9 @@ async def async_main() -> None: await session.commit() + trans_ctx = engine.begin() + async with trans_ctx as connection: + await connection.execute(select(A)) + asyncio.run(async_main()) diff --git a/test/typing/plain_files/ext/asyncio/async_stuff.py b/test/typing/plain_files/ext/asyncio/async_stuff.py new file mode 100644 index 0000000000..9afd0b8aff --- /dev/null +++ b/test/typing/plain_files/ext/asyncio/async_stuff.py @@ -0,0 +1,39 @@ +from asyncio import current_task + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine + + +engine = create_async_engine("") +SM = async_sessionmaker(engine, class_=AsyncSession) + +async_session = AsyncSession(engine) + +as_session = async_scoped_session(SM, current_task) + + +async def go() -> None: + r = await async_session.scalars(text("select 1"), params=[]) + r.first() + sr = await async_session.stream_scalars(text("select 1"), params=[]) + await sr.all() + r = await as_session.scalars(text("select 1"), params=[]) + r.first() + sr = await as_session.stream_scalars(text("select 1"), params=[]) + await sr.all() + + async with engine.connect() as conn: + cr = await conn.scalars(text("select 1")) + cr.first() + scr = await conn.stream_scalars(text("select 1")) + await scr.all() + + ast = async_session.get_transaction() + if ast: + ast.is_active + nt = async_session.get_nested_transaction() + if nt: + nt.is_active diff --git a/test/typing/plain_files/ext/asyncio/create_proxy_methods.py b/test/typing/plain_files/ext/asyncio/create_proxy_methods.py new file mode 100644 index 0000000000..235cf32ced --- /dev/null +++ b/test/typing/plain_files/ext/asyncio/create_proxy_methods.py @@ -0,0 +1,97 @@ +from sqlalchemy import text +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import AsyncConnection +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio.session import async_sessionmaker + +# async engine +async_engine: AsyncEngine = create_async_engine("") +async_engine.clear_compiled_cache() +async_engine.update_execution_options() +async_engine.get_execution_options() +async_engine.url +async_engine.pool +async_engine.dialect +async_engine.engine +async_engine.name +async_engine.driver +async_engine.echo + + +# async connection +async def go_async_conn() -> None: + async_conn: AsyncConnection = await async_engine.connect() + async_conn.closed + async_conn.invalidated + async_conn.dialect + async_conn.default_isolation_level + + +# async session +AsyncSession.object_session(object()) +AsyncSession.identity_key() +async_session: AsyncSession = AsyncSession(async_engine) +in_: bool = "foo" in async_session +list(async_session) +async_session.add(object()) +async_session.add_all([]) +async_session.expire(object()) +async_session.expire_all() +async_session.expunge(object()) +async_session.expunge_all() +async_session.get_bind() +async_session.is_modified(object()) +async_session.in_transaction() +async_session.in_nested_transaction() +async_session.dirty +async_session.deleted +async_session.new +async_session.identity_map +async_session.is_active +async_session.autoflush +async_session.no_autoflush +async_session.info + + +# async scoped session +async def test_async_scoped_session() -> None: + async_scoped_session.object_session(object()) + async_scoped_session.identity_key() + await async_scoped_session.close_all() + asm = async_sessionmaker() + async_ss = async_scoped_session(asm, lambda: 42) + value: bool = "foo" in async_ss + print(value) + list(async_ss) + async_ss.add(object()) + async_ss.add_all([]) + async_ss.begin() + async_ss.begin_nested() + await async_ss.close() + await async_ss.commit() + await async_ss.connection() + await async_ss.delete(object()) + await async_ss.execute(text("select 1")) + async_ss.expire(object()) + async_ss.expire_all() + async_ss.expunge(object()) + async_ss.expunge_all() + await async_ss.flush() + await async_ss.get(object, 1) + async_ss.get_bind() + async_ss.is_modified(object()) + await async_ss.merge(object()) + await async_ss.refresh(object()) + await async_ss.rollback() + await async_ss.scalar(text("select 1")) + async_ss.bind + async_ss.dirty + async_ss.deleted + async_ss.new + async_ss.identity_map + async_ss.is_active + async_ss.autoflush + async_ss.no_autoflush + async_ss.info diff --git a/test/ext/mypy/plain_files/engines.py b/test/typing/plain_files/ext/asyncio/engines.py similarity index 73% rename from test/ext/mypy/plain_files/engines.py rename to test/typing/plain_files/ext/asyncio/engines.py index b7621aca42..598d319a77 100644 --- a/test/ext/mypy/plain_files/engines.py +++ b/test/typing/plain_files/ext/asyncio/engines.py @@ -1,33 +1,7 @@ -from sqlalchemy import create_engine from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine -def regular() -> None: - e = create_engine("sqlite://") - - # EXPECTED_TYPE: Engine - reveal_type(e) - - with e.connect() as conn: - # EXPECTED_TYPE: Connection - reveal_type(conn) - - result = conn.execute(text("select * from table")) - - # EXPECTED_TYPE: CursorResult[Any] - reveal_type(result) - - with e.begin() as conn: - # EXPECTED_TYPE: Connection - reveal_type(conn) - - result = conn.execute(text("select * from table")) - - # EXPECTED_TYPE: CursorResult[Any] - reveal_type(result) - - async def asyncio() -> None: e = create_async_engine("sqlite://") diff --git a/test/ext/mypy/plain_files/hybrid_four.py b/test/typing/plain_files/ext/hybrid/hybrid_four.py similarity index 100% rename from test/ext/mypy/plain_files/hybrid_four.py rename to test/typing/plain_files/ext/hybrid/hybrid_four.py diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/typing/plain_files/ext/hybrid/hybrid_one.py similarity index 100% rename from test/ext/mypy/plain_files/hybrid_one.py rename to test/typing/plain_files/ext/hybrid/hybrid_one.py diff --git a/test/ext/mypy/plain_files/hybrid_three.py b/test/typing/plain_files/ext/hybrid/hybrid_three.py similarity index 100% rename from test/ext/mypy/plain_files/hybrid_three.py rename to test/typing/plain_files/ext/hybrid/hybrid_three.py diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/typing/plain_files/ext/hybrid/hybrid_two.py similarity index 100% rename from test/ext/mypy/plain_files/hybrid_two.py rename to test/typing/plain_files/ext/hybrid/hybrid_two.py diff --git a/test/ext/mypy/inspection_inspect.py b/test/typing/plain_files/inspection_inspect.py similarity index 78% rename from test/ext/mypy/inspection_inspect.py rename to test/typing/plain_files/inspection_inspect.py index c67b515f40..155ceffc03 100644 --- a/test/ext/mypy/inspection_inspect.py +++ b/test/typing/plain_files/inspection_inspect.py @@ -4,16 +4,21 @@ test inspect() however this is not really working """ +from typing import Any +from typing import Optional + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapper -Base = declarative_base() + +class Base(DeclarativeBase): + pass class A(Base): @@ -30,9 +35,9 @@ e = create_engine("sqlite://") # TODO: I can't get these to work, pylance and mypy both don't want # to accommodate for different types for the first argument -t: bool = inspect(a1).transient +t: Optional[Any] = inspect(a1) -m: Mapper = inspect(A) +m: Mapper[Any] = inspect(A) inspect(e).get_table_names() diff --git a/test/ext/mypy/plugin_files/complete_orm_no_plugin.py b/test/typing/plain_files/orm/complete_orm_no_plugin.py similarity index 99% rename from test/ext/mypy/plugin_files/complete_orm_no_plugin.py rename to test/typing/plain_files/orm/complete_orm_no_plugin.py index 53291501ad..b22057a2f3 100644 --- a/test/ext/mypy/plugin_files/complete_orm_no_plugin.py +++ b/test/typing/plain_files/orm/complete_orm_no_plugin.py @@ -1,4 +1,3 @@ -# NOPLUGINS # this should pass typing with no plugins from typing import Any diff --git a/test/ext/mypy/plain_files/composite.py b/test/typing/plain_files/orm/composite.py similarity index 100% rename from test/ext/mypy/plain_files/composite.py rename to test/typing/plain_files/orm/composite.py diff --git a/test/ext/mypy/plain_files/composite_dc.py b/test/typing/plain_files/orm/composite_dc.py similarity index 100% rename from test/ext/mypy/plain_files/composite_dc.py rename to test/typing/plain_files/orm/composite_dc.py diff --git a/test/ext/mypy/plain_files/dataclass_transforms_one.py b/test/typing/plain_files/orm/dataclass_transforms_one.py similarity index 100% rename from test/ext/mypy/plain_files/dataclass_transforms_one.py rename to test/typing/plain_files/orm/dataclass_transforms_one.py diff --git a/test/ext/mypy/plain_files/declared_attr_one.py b/test/typing/plain_files/orm/declared_attr_one.py similarity index 98% rename from test/ext/mypy/plain_files/declared_attr_one.py rename to test/typing/plain_files/orm/declared_attr_one.py index 86f8cf7704..fc304db87e 100644 --- a/test/ext/mypy/plain_files/declared_attr_one.py +++ b/test/typing/plain_files/orm/declared_attr_one.py @@ -30,7 +30,7 @@ class Employee(Base): } __table_args__ = ( - Index("my_index", name, type), + Index("my_index", name, type.desc()), UniqueConstraint(name), PrimaryKeyConstraint(id), {"prefix": []}, diff --git a/test/ext/mypy/plain_files/declared_attr_two.py b/test/typing/plain_files/orm/declared_attr_two.py similarity index 100% rename from test/ext/mypy/plain_files/declared_attr_two.py rename to test/typing/plain_files/orm/declared_attr_two.py diff --git a/test/ext/mypy/plain_files/dynamic_rel.py b/test/typing/plain_files/orm/dynamic_rel.py similarity index 100% rename from test/ext/mypy/plain_files/dynamic_rel.py rename to test/typing/plain_files/orm/dynamic_rel.py diff --git a/test/ext/mypy/plain_files/issue_9340.py b/test/typing/plain_files/orm/issue_9340.py similarity index 100% rename from test/ext/mypy/plain_files/issue_9340.py rename to test/typing/plain_files/orm/issue_9340.py diff --git a/test/ext/mypy/plain_files/keyfunc_dict.py b/test/typing/plain_files/orm/keyfunc_dict.py similarity index 100% rename from test/ext/mypy/plain_files/keyfunc_dict.py rename to test/typing/plain_files/orm/keyfunc_dict.py diff --git a/test/typing/plain_files/orm/mapped_assign_expression.py b/test/typing/plain_files/orm/mapped_assign_expression.py new file mode 100644 index 0000000000..e68b4b44a7 --- /dev/null +++ b/test/typing/plain_files/orm/mapped_assign_expression.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from sqlalchemy import create_engine +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import registry +from sqlalchemy.orm import Session +from sqlalchemy.sql.functions import now +from sqlalchemy.testing.schema import mapped_column + +mapper_registry: registry = registry() +e = create_engine("sqlite:///database.db", echo=True) + + +@mapper_registry.mapped +class A: + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + date_time: Mapped[datetime] + + +mapper_registry.metadata.create_all(e) + +with Session(e) as s: + a = A() + a.date_time = now() + s.add(a) + s.commit() diff --git a/test/ext/mypy/plain_files/mapped_column.py b/test/typing/plain_files/orm/mapped_column.py similarity index 100% rename from test/ext/mypy/plain_files/mapped_column.py rename to test/typing/plain_files/orm/mapped_column.py diff --git a/test/ext/mypy/plain_files/orm_config_constructs.py b/test/typing/plain_files/orm/orm_config_constructs.py similarity index 100% rename from test/ext/mypy/plain_files/orm_config_constructs.py rename to test/typing/plain_files/orm/orm_config_constructs.py diff --git a/test/ext/mypy/plain_files/orm_querying.py b/test/typing/plain_files/orm/orm_querying.py similarity index 100% rename from test/ext/mypy/plain_files/orm_querying.py rename to test/typing/plain_files/orm/orm_querying.py diff --git a/test/ext/mypy/plain_files/experimental_relationship.py b/test/typing/plain_files/orm/relationship.py similarity index 95% rename from test/ext/mypy/plain_files/experimental_relationship.py rename to test/typing/plain_files/orm/relationship.py index 7acec89e18..ddd51e21e4 100644 --- a/test/ext/mypy/plain_files/experimental_relationship.py +++ b/test/typing/plain_files/orm/relationship.py @@ -85,8 +85,8 @@ if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\] reveal_type(Address.email_name) - # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.Address\]\] reveal_type(User.addresses_style_one) - # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\] + # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\] reveal_type(User.addresses_style_two) diff --git a/test/typing/plain_files/orm/scoped_session.py b/test/typing/plain_files/orm/scoped_session.py new file mode 100644 index 0000000000..9809901902 --- /dev/null +++ b/test/typing/plain_files/orm/scoped_session.py @@ -0,0 +1,58 @@ +from sqlalchemy import inspect +from sqlalchemy import text +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import sessionmaker + + +class Base(DeclarativeBase): + pass + + +class X(Base): + __tablename__ = "x" + id: Mapped[int] = mapped_column(primary_key=True) + + +scoped_session.object_session(object()) +scoped_session.identity_key() +scoped_session.close_all() +ss = scoped_session(sessionmaker()) +value: bool = "foo" in ss +list(ss) +ss.add(object()) +ss.add_all([]) +ss.begin() +ss.begin_nested() +ss.close() +ss.commit() +ss.connection() +ss.delete(object()) +ss.execute(text("select 1")) +ss.expire(object()) +ss.expire_all() +ss.expunge(object()) +ss.expunge_all() +ss.flush() +ss.get(object, 1) +b = ss.get_bind() +ss.is_modified(object()) +ss.bulk_save_objects([]) +ss.bulk_insert_mappings(inspect(X), []) +ss.bulk_update_mappings(inspect(X), []) +ss.merge(object()) +q = (ss.query(object),) +ss.refresh(object()) +ss.rollback() +ss.scalar(text("select 1")) +ss.bind +ss.dirty +ss.deleted +ss.new +ss.identity_map +ss.is_active +ss.autoflush +ss.no_autoflush +ss.info diff --git a/test/ext/mypy/plain_files/session.py b/test/typing/plain_files/orm/session.py similarity index 100% rename from test/ext/mypy/plain_files/session.py rename to test/typing/plain_files/orm/session.py diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/typing/plain_files/orm/sessionmakers.py similarity index 100% rename from test/ext/mypy/plain_files/sessionmakers.py rename to test/typing/plain_files/orm/sessionmakers.py diff --git a/test/ext/mypy/plain_files/trad_relationship_uselist.py b/test/typing/plain_files/orm/trad_relationship_uselist.py similarity index 100% rename from test/ext/mypy/plain_files/trad_relationship_uselist.py rename to test/typing/plain_files/orm/trad_relationship_uselist.py diff --git a/test/ext/mypy/plain_files/traditional_relationship.py b/test/typing/plain_files/orm/traditional_relationship.py similarity index 100% rename from test/ext/mypy/plain_files/traditional_relationship.py rename to test/typing/plain_files/orm/traditional_relationship.py diff --git a/test/ext/mypy/plain_files/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py similarity index 100% rename from test/ext/mypy/plain_files/typed_queries.py rename to test/typing/plain_files/orm/typed_queries.py diff --git a/test/ext/mypy/plain_files/write_only.py b/test/typing/plain_files/orm/write_only.py similarity index 100% rename from test/ext/mypy/plain_files/write_only.py rename to test/typing/plain_files/orm/write_only.py diff --git a/test/ext/mypy/plain_files/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py similarity index 100% rename from test/ext/mypy/plain_files/common_sql_element.py rename to test/typing/plain_files/sql/common_sql_element.py diff --git a/test/typing/plain_files/sql/core_ddl.py b/test/typing/plain_files/sql/core_ddl.py new file mode 100644 index 0000000000..b7e0ec5350 --- /dev/null +++ b/test/typing/plain_files/sql/core_ddl.py @@ -0,0 +1,151 @@ +from sqlalchemy import Boolean +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import FetchedValue +from sqlalchemy import ForeignKey +from sqlalchemy import func +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import literal_column +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import text +from sqlalchemy import true +from sqlalchemy import UUID + + +m = MetaData() + + +t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("data2", String(50)), + Column("timestamp", DateTime()), + Index(None, "data2"), +) + +t2 = Table( + "t2", + m, + Column("t1id", ForeignKey("t1.id")), + Column("q", Integer, CheckConstraint("q > 5")), +) + +t3 = Table( + "t3", + m, + Column("x", Integer), + Column("y", Integer), + Column("t1id", ForeignKey(t1.c.id)), + PrimaryKeyConstraint("x", "y"), +) + +t4 = Table( + "test_table", + m, + Column("i", UUID(as_uuid=True), nullable=False, primary_key=True), + Column("x", UUID(as_uuid=True), index=True), + Column("y", UUID(as_uuid=False), index=True), + Index("ix_xy_unique", "x", "y", unique=True), +) + + +# cols w/ no name or type, used by declarative +c1: Column[int] = Column(ForeignKey(t3.c.x)) +# more colum args +Column("name", Integer, index=True) +Column(None, name="name") +Column(Integer, name="name", index=True) +Column("name", ForeignKey("a.id")) +Column(ForeignKey("a.id"), type_=None, index=True) +Column(ForeignKey("a.id"), name="name", type_=Integer()) +Column("name", None) +Column("name", index=True) +Column(ForeignKey("a.id"), name="name", index=True) +Column(type_=None, index=True) +Column(None, ForeignKey("a.id")) +Column("name") +Column(name="name", type_=None, index=True) +Column(ForeignKey("a.id"), name="name", type_=None) +Column(Integer) +Column(ForeignKey("a.id"), type_=Integer()) +Column("name", Integer, ForeignKey("a.id"), index=True) +Column("name", None, ForeignKey("a.id"), index=True) +Column(ForeignKey("a.id"), index=True) +Column("name", Integer) +Column(Integer, name="name") +Column(Integer, ForeignKey("a.id"), name="name", index=True) +Column(ForeignKey("a.id"), type_=None) +Column(ForeignKey("a.id"), name="name") +Column(name="name", index=True) +Column(type_=None) +Column(None, index=True) +Column(name="name", type_=None) +Column(type_=Integer(), index=True) +Column("name", Integer, ForeignKey("a.id")) +Column(name="name", type_=Integer(), index=True) +Column(Integer, ForeignKey("a.id"), index=True) +Column("name", None, ForeignKey("a.id")) +Column(index=True) +Column("name", type_=None, index=True) +Column("name", ForeignKey("a.id"), type_=Integer(), index=True) +Column(ForeignKey("a.id")) +Column(Integer, ForeignKey("a.id")) +Column(Integer, ForeignKey("a.id"), name="name") +Column("name", ForeignKey("a.id"), index=True) +Column("name", type_=Integer(), index=True) +Column(ForeignKey("a.id"), name="name", type_=Integer(), index=True) +Column(name="name") +Column("name", None, index=True) +Column("name", ForeignKey("a.id"), type_=None, index=True) +Column("name", type_=Integer()) +Column(None) +Column(None, ForeignKey("a.id"), index=True) +Column("name", ForeignKey("a.id"), type_=None) +Column(type_=Integer()) +Column(None, ForeignKey("a.id"), name="name", index=True) +Column(Integer, index=True) +Column(ForeignKey("a.id"), name="name", type_=None, index=True) +Column(ForeignKey("a.id"), type_=Integer(), index=True) +Column(name="name", type_=Integer()) +Column(None, name="name", index=True) +Column() +Column(None, ForeignKey("a.id"), name="name") +Column("name", type_=None) +Column("name", ForeignKey("a.id"), type_=Integer()) + +# server_default +Column(Boolean, nullable=False, server_default=true()) +Column(DateTime, server_default=func.now(), nullable=False) +Column(Boolean, server_default=func.xyzq(), nullable=False) +# what would be *nice* to emit an error would be this, but this +# is really not important, people don't usually put types in functions +# as they are usually part of a bigger context where the type is known +Column(Boolean, server_default=func.xyzq(type_=DateTime), nullable=False) +Column(DateTime, server_default="now()") +Column(DateTime, server_default=text("now()")) +Column(DateTime, server_default=FetchedValue()) +Column(Boolean, server_default=literal_column("false", Boolean)) +Column("name", server_default=FetchedValue(), nullable=False) +Column(server_default="now()", nullable=False) +Column("name", Integer, server_default=text("now()"), nullable=False) +Column(Integer, server_default=literal_column("42", Integer), nullable=False) + +# server_onupdate +Column("name", server_onupdate=FetchedValue(), nullable=False) +Column(server_onupdate=FetchedValue(), nullable=False) +Column("name", Integer, server_onupdate=FetchedValue(), nullable=False) +Column(Integer, server_onupdate=FetchedValue(), nullable=False) + +# TypeEngine.with_variant should accept both a TypeEngine instance and the Concrete Type +Integer().with_variant(Integer, "mysql") +Integer().with_variant(Integer(), "mysql") +# Also test Variant.with_variant +Integer().with_variant(Integer, "mysql").with_variant(Integer, "mysql") +Integer().with_variant(Integer, "mysql").with_variant(Integer(), "mysql") diff --git a/test/ext/mypy/plain_files/dml.py b/test/typing/plain_files/sql/dml.py similarity index 100% rename from test/ext/mypy/plain_files/dml.py rename to test/typing/plain_files/sql/dml.py diff --git a/test/ext/mypy/plain_files/functions.py b/test/typing/plain_files/sql/functions.py similarity index 100% rename from test/ext/mypy/plain_files/functions.py rename to test/typing/plain_files/sql/functions.py diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py new file mode 100644 index 0000000000..edfbd6bb2b --- /dev/null +++ b/test/typing/plain_files/sql/functions_again.py @@ -0,0 +1,23 @@ +from sqlalchemy import func +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Foo(Base): + __tablename__ = "foo" + + id: Mapped[int] = mapped_column(primary_key=True) + a: Mapped[int] + b: Mapped[int] + + +func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()) +func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()]) +func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()]) +func.row_number().over(order_by="a", partition_by=("a", "b")) +func.row_number().over(partition_by="a", order_by=("a", "b")) diff --git a/test/ext/mypy/plain_files/lambda_stmt.py b/test/typing/plain_files/sql/lambda_stmt.py similarity index 100% rename from test/ext/mypy/plain_files/lambda_stmt.py rename to test/typing/plain_files/sql/lambda_stmt.py diff --git a/test/typing/plain_files/sql/lowercase_objects.py b/test/typing/plain_files/sql/lowercase_objects.py new file mode 100644 index 0000000000..ab26d7ede3 --- /dev/null +++ b/test/typing/plain_files/sql/lowercase_objects.py @@ -0,0 +1,16 @@ +import sqlalchemy as sa + +Book = sa.table( + "book", + sa.column("id", sa.Integer), + sa.column("name", sa.String), +) +Book.append_column(sa.column("other")) +Book.corresponding_column(Book.c.id) + +value_expr = sa.values( + sa.column("id", sa.Integer), sa.column("name", sa.String), name="my_values" +).data([(1, "name1"), (2, "name2"), (3, "name3")]) + +sa.select(Book) +sa.select(sa.literal_column("42"), sa.column("foo")).select_from(sa.table("t")) diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py new file mode 100644 index 0000000000..41981d155b --- /dev/null +++ b/test/typing/plain_files/sql/operators.py @@ -0,0 +1,137 @@ +from decimal import Decimal +from typing import Any + +from sqlalchemy import ARRAY +from sqlalchemy import BigInteger +from sqlalchemy import column +from sqlalchemy import ColumnElement +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class A(Base): + __tablename__ = "a" + id: Mapped[int] + string: Mapped[str] + arr: Mapped[list[int]] = mapped_column(ARRAY(Integer)) + + +lt1: "ColumnElement[bool]" = A.id > A.id +lt2: "ColumnElement[bool]" = A.id > 1 +lt3: "ColumnElement[bool]" = 1 < A.id + +le1: "ColumnElement[bool]" = A.id >= A.id +le2: "ColumnElement[bool]" = A.id >= 1 +le3: "ColumnElement[bool]" = 1 <= A.id + +eq1: "ColumnElement[bool]" = A.id == A.id +eq2: "ColumnElement[bool]" = A.id == 1 +# eq3: "ColumnElement[bool]" = 1 == A.id + +ne1: "ColumnElement[bool]" = A.id != A.id +ne2: "ColumnElement[bool]" = A.id != 1 +# ne3: "ColumnElement[bool]" = 1 != A.id + +gt1: "ColumnElement[bool]" = A.id < A.id +gt2: "ColumnElement[bool]" = A.id < 1 +gt3: "ColumnElement[bool]" = 1 > A.id + +ge1: "ColumnElement[bool]" = A.id <= A.id +ge2: "ColumnElement[bool]" = A.id <= 1 +ge3: "ColumnElement[bool]" = 1 >= A.id + + +# TODO "in" doesn't seem to pick up the typing of __contains__? +# EXPECTED_MYPY: Incompatible types in assignment (expression has type "bool", variable has type "ColumnElement[bool]") # noqa: E501 +contains1: "ColumnElement[bool]" = A.id in A.arr +# EXPECTED_MYPY: Incompatible types in assignment (expression has type "bool", variable has type "ColumnElement[bool]") # noqa: E501 +contains2: "ColumnElement[bool]" = A.id in A.string + +lshift1: "ColumnElement[int]" = A.id << A.id +lshift2: "ColumnElement[int]" = A.id << 1 +lshift3: "ColumnElement[Any]" = A.string << 1 + +rshift1: "ColumnElement[int]" = A.id >> A.id +rshift2: "ColumnElement[int]" = A.id >> 1 +rshift3: "ColumnElement[Any]" = A.string >> 1 + +concat1: "ColumnElement[str]" = A.string.concat(A.string) +concat2: "ColumnElement[str]" = A.string.concat(1) +concat3: "ColumnElement[str]" = A.string.concat("a") + +like1: "ColumnElement[bool]" = A.string.like("test") +like2: "ColumnElement[bool]" = A.string.like("test", escape="/") +ilike1: "ColumnElement[bool]" = A.string.ilike("test") +ilike2: "ColumnElement[bool]" = A.string.ilike("test", escape="/") + +in_: "ColumnElement[bool]" = A.id.in_([1, 2]) +not_in: "ColumnElement[bool]" = A.id.not_in([1, 2]) + +not_like1: "ColumnElement[bool]" = A.string.not_like("test") +not_like2: "ColumnElement[bool]" = A.string.not_like("test", escape="/") +not_ilike1: "ColumnElement[bool]" = A.string.not_ilike("test") +not_ilike2: "ColumnElement[bool]" = A.string.not_ilike("test", escape="/") + +is_: "ColumnElement[bool]" = A.string.is_("test") +is_not: "ColumnElement[bool]" = A.string.is_not("test") + +startswith: "ColumnElement[bool]" = A.string.startswith("test") +endswith: "ColumnElement[bool]" = A.string.endswith("test") +contains: "ColumnElement[bool]" = A.string.contains("test") +match: "ColumnElement[bool]" = A.string.match("test") +regexp_match: "ColumnElement[bool]" = A.string.regexp_match("test") + +regexp_replace: "ColumnElement[str]" = A.string.regexp_replace( + "pattern", "replacement" +) +between: "ColumnElement[bool]" = A.string.between("a", "b") + +adds: "ColumnElement[str]" = A.string + A.string +add1: "ColumnElement[int]" = A.id + A.id +add2: "ColumnElement[int]" = A.id + 1 +add3: "ColumnElement[int]" = 1 + A.id + +sub1: "ColumnElement[int]" = A.id - A.id +sub2: "ColumnElement[int]" = A.id - 1 +sub3: "ColumnElement[int]" = 1 - A.id + +mul1: "ColumnElement[int]" = A.id * A.id +mul2: "ColumnElement[int]" = A.id * 1 +mul3: "ColumnElement[int]" = 1 * A.id + +div1: "ColumnElement[float|Decimal]" = A.id / A.id +div2: "ColumnElement[float|Decimal]" = A.id / 1 +div3: "ColumnElement[float|Decimal]" = 1 / A.id + +mod1: "ColumnElement[int]" = A.id % A.id +mod2: "ColumnElement[int]" = A.id % 1 +mod3: "ColumnElement[int]" = 1 % A.id + +# unary + +neg: "ColumnElement[int]" = -A.id + +desc: "ColumnElement[int]" = A.id.desc() +asc: "ColumnElement[int]" = A.id.asc() +any_: "ColumnElement[bool]" = A.id.any_() +all_: "ColumnElement[bool]" = A.id.all_() +nulls_first: "ColumnElement[int]" = A.id.nulls_first() +nulls_last: "ColumnElement[int]" = A.id.nulls_last() +collate: "ColumnElement[str]" = A.string.collate("somelang") +distinct: "ColumnElement[int]" = A.id.distinct() + + +# custom ops +col = column("flags", Integer) +op_a: "ColumnElement[Any]" = col.op("&")(1) +op_b: "ColumnElement[int]" = col.op("&", return_type=Integer)(1) +op_c: "ColumnElement[str]" = col.op("&", return_type=String)("1") +op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1") +op_e: "ColumnElement[bool]" = col.bool_op("&")("1") diff --git a/test/ext/mypy/plain_files/selectables.py b/test/typing/plain_files/sql/selectables.py similarity index 100% rename from test/ext/mypy/plain_files/selectables.py rename to test/typing/plain_files/sql/selectables.py diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/typing/plain_files/sql/sql_operations.py similarity index 100% rename from test/ext/mypy/plain_files/sql_operations.py rename to test/typing/plain_files/sql/sql_operations.py diff --git a/test/ext/mypy/plain_files/sqltypes.py b/test/typing/plain_files/sql/sqltypes.py similarity index 100% rename from test/ext/mypy/plain_files/sqltypes.py rename to test/typing/plain_files/sql/sqltypes.py diff --git a/test/ext/mypy/plain_files/typed_results.py b/test/typing/plain_files/sql/typed_results.py similarity index 100% rename from test/ext/mypy/plain_files/typed_results.py rename to test/typing/plain_files/sql/typed_results.py diff --git a/test/typing/test_mypy.py b/test/typing/test_mypy.py new file mode 100644 index 0000000000..14d13bd6f5 --- /dev/null +++ b/test/typing/test_mypy.py @@ -0,0 +1,17 @@ +import os + +from sqlalchemy import testing +from sqlalchemy.testing import fixtures + + +class MypyPlainTest(fixtures.MypyTest): + @testing.combinations( + *( + (os.path.basename(path), path) + for path in fixtures.MypyTest.file_combinations("plain_files") + ), + argnames="path", + id_="ia", + ) + def test_mypy_no_plugin(self, mypy_typecheck_file, path): + mypy_typecheck_file(path) diff --git a/test/ext/mypy/test_overloads.py b/test/typing/test_overloads.py similarity index 100% rename from test/ext/mypy/test_overloads.py rename to test/typing/test_overloads.py diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 5845e89ad9..848a927225 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -136,7 +136,7 @@ def main(cmd: code_writer_cmd) -> None: functions_py = "lib/sqlalchemy/sql/functions.py" -test_functions_py = "test/ext/mypy/plain_files/functions.py" +test_functions_py = "test/typing/plain_files/sql/functions.py" if __name__ == "__main__":