state._modified_event(dict_, self, attributes.NEVER_SET)
- # this is a hack to allow the fixtures.ComparableEntity fixture
+ # this is a hack to allow the entities.ComparableEntity fixture
# to work
dict_[self.key] = True
return state.committed_state[self.key]
from __future__ import annotations
+from argparse import Namespace
import collections
import inspect
import typing
any_async = False
_current = None
ident = "main"
+options: Namespace = None # type: ignore
if typing.TYPE_CHECKING:
from .plugin.plugin_base import FixtureFunctions
+++ /dev/null
-# testing/fixtures.py
-# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
-
-
-from __future__ import annotations
-
-import itertools
-import random
-import re
-import sys
-from typing import Any
-
-import sqlalchemy as sa
-from . import assertions
-from . import config
-from . import mock
-from . import schema
-from .assertions import eq_
-from .assertions import ne_
-from .entities import BasicEntity
-from .entities import ComparableEntity
-from .entities import ComparableMixin # noqa
-from .util import adict
-from .util import drop_all_tables_from_metadata
-from .. import Column
-from .. import event
-from .. import func
-from .. import Integer
-from .. import select
-from .. import Table
-from .. import util
-from ..orm import DeclarativeBase
-from ..orm import events as orm_events
-from ..orm import MappedAsDataclass
-from ..orm import registry
-from ..schema import sort_tables_and_constraints
-from ..sql import visitors
-from ..sql.elements import ClauseElement
-
-
-@config.mark_base_test_class()
-class TestBase:
- # A sequence of requirement names matching testing.requires decorators
- __requires__ = ()
-
- # A sequence of dialect names to exclude from the test class.
- __unsupported_on__ = ()
-
- # If present, test class is only runnable for the *single* specified
- # dialect. If you need multiple, use __unsupported_on__ and invert.
- __only_on__ = None
-
- # A sequence of no-arg callables. If any are True, the entire testcase is
- # skipped.
- __skip_if__ = None
-
- # if True, the testing reaper will not attempt to touch connection
- # state after a test is completed and before the outer teardown
- # starts
- __leave_connections_for_teardown__ = False
-
- def assert_(self, val, msg=None):
- assert val, msg
-
- @config.fixture()
- def nocache(self):
- _cache = config.db._compiled_cache
- config.db._compiled_cache = None
- yield
- config.db._compiled_cache = _cache
-
- @config.fixture()
- def connection_no_trans(self):
- eng = getattr(self, "bind", None) or config.db
-
- with eng.connect() as conn:
- yield conn
-
- @config.fixture()
- def connection(self):
- global _connection_fixture_connection
-
- eng = getattr(self, "bind", None) or config.db
-
- conn = eng.connect()
- trans = conn.begin()
-
- _connection_fixture_connection = conn
- yield conn
-
- _connection_fixture_connection = None
-
- if trans.is_active:
- trans.rollback()
- # trans would not be active here if the test is using
- # the legacy @provide_metadata decorator still, as it will
- # run a close all connections.
- conn.close()
-
- @config.fixture()
- def close_result_when_finished(self):
- to_close = []
- to_consume = []
-
- def go(result, consume=False):
- to_close.append(result)
- if consume:
- to_consume.append(result)
-
- yield go
- for r in to_consume:
- try:
- r.all()
- except:
- pass
- for r in to_close:
- try:
- r.close()
- except:
- pass
-
- @config.fixture()
- def registry(self, metadata):
- reg = registry(
- metadata=metadata,
- type_annotation_map={
- str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb", "oracle"
- )
- },
- )
- yield reg
- reg.dispose()
-
- @config.fixture
- def decl_base(self, metadata):
- _md = metadata
-
- class Base(DeclarativeBase):
- metadata = _md
- type_annotation_map = {
- str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb", "oracle"
- )
- }
-
- yield Base
- Base.registry.dispose()
-
- @config.fixture
- def dc_decl_base(self, metadata):
- _md = metadata
-
- class Base(MappedAsDataclass, DeclarativeBase):
- metadata = _md
- type_annotation_map = {
- str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb"
- )
- }
-
- yield Base
- Base.registry.dispose()
-
- @config.fixture()
- def future_connection(self, future_engine, connection):
- # integrate the future_engine and connection fixtures so
- # that users of the "connection" fixture will get at the
- # "future" connection
- yield connection
-
- @config.fixture()
- def future_engine(self):
- yield
-
- @config.fixture()
- def testing_engine(self):
- from . import engines
-
- def gen_testing_engine(
- url=None,
- options=None,
- future=None,
- asyncio=False,
- transfer_staticpool=False,
- share_pool=False,
- ):
- if options is None:
- options = {}
- options["scope"] = "fixture"
- return engines.testing_engine(
- url=url,
- options=options,
- asyncio=asyncio,
- transfer_staticpool=transfer_staticpool,
- share_pool=share_pool,
- )
-
- yield gen_testing_engine
-
- engines.testing_reaper._drop_testing_engines("fixture")
-
- @config.fixture()
- def async_testing_engine(self, testing_engine):
- def go(**kw):
- kw["asyncio"] = True
- return testing_engine(**kw)
-
- return go
-
- @config.fixture
- def fixture_session(self):
- return fixture_session()
-
- @config.fixture()
- def metadata(self, request):
- """Provide bound MetaData for a single test, dropping afterwards."""
-
- from ..sql import schema
-
- metadata = schema.MetaData()
- request.instance.metadata = metadata
- yield metadata
- del request.instance.metadata
-
- if (
- _connection_fixture_connection
- and _connection_fixture_connection.in_transaction()
- ):
- trans = _connection_fixture_connection.get_transaction()
- trans.rollback()
- with _connection_fixture_connection.begin():
- drop_all_tables_from_metadata(
- metadata, _connection_fixture_connection
- )
- else:
- drop_all_tables_from_metadata(metadata, config.db)
-
- @config.fixture(
- params=[
- (rollback, second_operation, begin_nested)
- for rollback in (True, False)
- for second_operation in ("none", "execute", "begin")
- for begin_nested in (
- True,
- False,
- )
- ]
- )
- def trans_ctx_manager_fixture(self, request, metadata):
- rollback, second_operation, begin_nested = request.param
-
- t = Table("test", metadata, Column("data", Integer))
- eng = getattr(self, "bind", None) or config.db
-
- t.create(eng)
-
- def run_test(subject, trans_on_subject, execute_on_subject):
- with subject.begin() as trans:
- if begin_nested:
- if not config.requirements.savepoints.enabled:
- config.skip_test("savepoints not enabled")
- if execute_on_subject:
- nested_trans = subject.begin_nested()
- else:
- nested_trans = trans.begin_nested()
-
- with nested_trans:
- if execute_on_subject:
- subject.execute(t.insert(), {"data": 10})
- else:
- trans.execute(t.insert(), {"data": 10})
-
- # for nested trans, we always commit/rollback on the
- # "nested trans" object itself.
- # only Session(future=False) will affect savepoint
- # transaction for session.commit/rollback
-
- if rollback:
- nested_trans.rollback()
- else:
- nested_trans.commit()
-
- if second_operation != "none":
- with assertions.expect_raises_message(
- sa.exc.InvalidRequestError,
- "Can't operate on closed transaction "
- "inside context "
- "manager. Please complete the context "
- "manager "
- "before emitting further commands.",
- ):
- if second_operation == "execute":
- if execute_on_subject:
- subject.execute(
- t.insert(), {"data": 12}
- )
- else:
- trans.execute(t.insert(), {"data": 12})
- elif second_operation == "begin":
- if execute_on_subject:
- subject.begin_nested()
- else:
- trans.begin_nested()
-
- # outside the nested trans block, but still inside the
- # transaction block, we can run SQL, and it will be
- # committed
- if execute_on_subject:
- subject.execute(t.insert(), {"data": 14})
- else:
- trans.execute(t.insert(), {"data": 14})
-
- else:
- if execute_on_subject:
- subject.execute(t.insert(), {"data": 10})
- else:
- trans.execute(t.insert(), {"data": 10})
-
- if trans_on_subject:
- if rollback:
- subject.rollback()
- else:
- subject.commit()
- else:
- if rollback:
- trans.rollback()
- else:
- trans.commit()
-
- if second_operation != "none":
- with assertions.expect_raises_message(
- sa.exc.InvalidRequestError,
- "Can't operate on closed transaction inside "
- "context "
- "manager. Please complete the context manager "
- "before emitting further commands.",
- ):
- if second_operation == "execute":
- if execute_on_subject:
- subject.execute(t.insert(), {"data": 12})
- else:
- trans.execute(t.insert(), {"data": 12})
- elif second_operation == "begin":
- if hasattr(trans, "begin"):
- trans.begin()
- else:
- subject.begin()
- elif second_operation == "begin_nested":
- if execute_on_subject:
- subject.begin_nested()
- else:
- trans.begin_nested()
-
- expected_committed = 0
- if begin_nested:
- # begin_nested variant, we inserted a row after the nested
- # block
- expected_committed += 1
- if not rollback:
- # not rollback variant, our row inserted in the target
- # block itself would be committed
- expected_committed += 1
-
- if execute_on_subject:
- eq_(
- subject.scalar(select(func.count()).select_from(t)),
- expected_committed,
- )
- else:
- with subject.connect() as conn:
- eq_(
- conn.scalar(select(func.count()).select_from(t)),
- expected_committed,
- )
-
- return run_test
-
-
-_connection_fixture_connection = None
-
-
-class FutureEngineMixin:
- """alembic's suite still using this"""
-
-
-class TablesTest(TestBase):
- # 'once', None
- run_setup_bind = "once"
-
- # 'once', 'each', None
- run_define_tables = "once"
-
- # 'once', 'each', None
- run_create_tables = "once"
-
- # 'once', 'each', None
- run_inserts = "each"
-
- # 'each', None
- run_deletes = "each"
-
- # 'once', None
- run_dispose_bind = None
-
- bind = None
- _tables_metadata = None
- tables = None
- other = None
- sequences = None
-
- @config.fixture(autouse=True, scope="class")
- def _setup_tables_test_class(self):
- cls = self.__class__
- cls._init_class()
-
- cls._setup_once_tables()
-
- cls._setup_once_inserts()
-
- yield
-
- cls._teardown_once_metadata_bind()
-
- @config.fixture(autouse=True, scope="function")
- def _setup_tables_test_instance(self):
- self._setup_each_tables()
- self._setup_each_inserts()
-
- yield
-
- self._teardown_each_tables()
-
- @property
- def tables_test_metadata(self):
- return self._tables_metadata
-
- @classmethod
- def _init_class(cls):
- if cls.run_define_tables == "each":
- if cls.run_create_tables == "once":
- cls.run_create_tables = "each"
- assert cls.run_inserts in ("each", None)
-
- cls.other = adict()
- cls.tables = adict()
- cls.sequences = adict()
-
- cls.bind = cls.setup_bind()
- cls._tables_metadata = sa.MetaData()
-
- @classmethod
- def _setup_once_inserts(cls):
- if cls.run_inserts == "once":
- cls._load_fixtures()
- with cls.bind.begin() as conn:
- cls.insert_data(conn)
-
- @classmethod
- def _setup_once_tables(cls):
- if cls.run_define_tables == "once":
- cls.define_tables(cls._tables_metadata)
- if cls.run_create_tables == "once":
- cls._tables_metadata.create_all(cls.bind)
- cls.tables.update(cls._tables_metadata.tables)
- cls.sequences.update(cls._tables_metadata._sequences)
-
- def _setup_each_tables(self):
- if self.run_define_tables == "each":
- self.define_tables(self._tables_metadata)
- if self.run_create_tables == "each":
- self._tables_metadata.create_all(self.bind)
- self.tables.update(self._tables_metadata.tables)
- self.sequences.update(self._tables_metadata._sequences)
- elif self.run_create_tables == "each":
- self._tables_metadata.create_all(self.bind)
-
- def _setup_each_inserts(self):
- if self.run_inserts == "each":
- self._load_fixtures()
- with self.bind.begin() as conn:
- self.insert_data(conn)
-
- def _teardown_each_tables(self):
- if self.run_define_tables == "each":
- self.tables.clear()
- if self.run_create_tables == "each":
- drop_all_tables_from_metadata(self._tables_metadata, self.bind)
- self._tables_metadata.clear()
- elif self.run_create_tables == "each":
- drop_all_tables_from_metadata(self._tables_metadata, self.bind)
-
- savepoints = getattr(config.requirements, "savepoints", False)
- if savepoints:
- savepoints = savepoints.enabled
-
- # no need to run deletes if tables are recreated on setup
- if (
- self.run_define_tables != "each"
- and self.run_create_tables != "each"
- and self.run_deletes == "each"
- ):
- with self.bind.begin() as conn:
- for table in reversed(
- [
- t
- for (t, fks) in sort_tables_and_constraints(
- self._tables_metadata.tables.values()
- )
- if t is not None
- ]
- ):
- try:
- if savepoints:
- with conn.begin_nested():
- conn.execute(table.delete())
- else:
- conn.execute(table.delete())
- except sa.exc.DBAPIError as ex:
- print(
- ("Error emptying table %s: %r" % (table, ex)),
- file=sys.stderr,
- )
-
- @classmethod
- def _teardown_once_metadata_bind(cls):
- if cls.run_create_tables:
- drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
-
- if cls.run_dispose_bind == "once":
- cls.dispose_bind(cls.bind)
-
- cls._tables_metadata.bind = None
-
- if cls.run_setup_bind is not None:
- cls.bind = None
-
- @classmethod
- def setup_bind(cls):
- return config.db
-
- @classmethod
- def dispose_bind(cls, bind):
- if hasattr(bind, "dispose"):
- bind.dispose()
- elif hasattr(bind, "close"):
- bind.close()
-
- @classmethod
- def define_tables(cls, metadata):
- pass
-
- @classmethod
- def fixtures(cls):
- return {}
-
- @classmethod
- def insert_data(cls, connection):
- pass
-
- def sql_count_(self, count, fn):
- self.assert_sql_count(self.bind, fn, count)
-
- def sql_eq_(self, callable_, statements):
- self.assert_sql(self.bind, callable_, statements)
-
- @classmethod
- def _load_fixtures(cls):
- """Insert rows as represented by the fixtures() method."""
- headers, rows = {}, {}
- for table, data in cls.fixtures().items():
- if len(data) < 2:
- continue
- if isinstance(table, str):
- table = cls.tables[table]
- headers[table] = data[0]
- rows[table] = data[1:]
- for table, fks in sort_tables_and_constraints(
- cls._tables_metadata.tables.values()
- ):
- if table is None:
- continue
- if table not in headers:
- continue
- with cls.bind.begin() as conn:
- conn.execute(
- table.insert(),
- [
- dict(zip(headers[table], column_values))
- for column_values in rows[table]
- ],
- )
-
-
-class NoCache:
- @config.fixture(autouse=True, scope="function")
- def _disable_cache(self):
- _cache = config.db._compiled_cache
- config.db._compiled_cache = None
- yield
- config.db._compiled_cache = _cache
-
-
-class RemovesEvents:
- @util.memoized_property
- def _event_fns(self):
- return set()
-
- def event_listen(self, target, name, fn, **kw):
- self._event_fns.add((target, name, fn))
- event.listen(target, name, fn, **kw)
-
- @config.fixture(autouse=True, scope="function")
- def _remove_events(self):
- yield
- for key in self._event_fns:
- event.remove(*key)
-
-
-class RemoveORMEventsGlobally:
- @config.fixture(autouse=True)
- def _remove_listeners(self):
- yield
- orm_events.MapperEvents._clear()
- orm_events.InstanceEvents._clear()
- orm_events.SessionEvents._clear()
- orm_events.InstrumentationEvents._clear()
- orm_events.QueryEvents._clear()
-
-
-_fixture_sessions = set()
-
-
-def fixture_session(**kw):
- kw.setdefault("autoflush", True)
- kw.setdefault("expire_on_commit", True)
-
- bind = kw.pop("bind", config.db)
-
- sess = sa.orm.Session(bind, **kw)
- _fixture_sessions.add(sess)
- return sess
-
-
-def _close_all_sessions():
- # will close all still-referenced sessions
- sa.orm.session.close_all_sessions()
- _fixture_sessions.clear()
-
-
-def stop_test_class_inside_fixtures(cls):
- _close_all_sessions()
- sa.orm.clear_mappers()
-
-
-def after_test():
- if _fixture_sessions:
- _close_all_sessions()
-
-
-class ORMTest(TestBase):
- pass
-
-
-class MappedTest(TablesTest, assertions.AssertsExecutionResults):
- # 'once', 'each', None
- run_setup_classes = "once"
-
- # 'once', 'each', None
- run_setup_mappers = "each"
-
- classes: Any = None
-
- @config.fixture(autouse=True, scope="class")
- def _setup_tables_test_class(self):
- cls = self.__class__
- cls._init_class()
-
- if cls.classes is None:
- cls.classes = adict()
-
- cls._setup_once_tables()
- cls._setup_once_classes()
- cls._setup_once_mappers()
- cls._setup_once_inserts()
-
- yield
-
- cls._teardown_once_class()
- cls._teardown_once_metadata_bind()
-
- @config.fixture(autouse=True, scope="function")
- def _setup_tables_test_instance(self):
- self._setup_each_tables()
- self._setup_each_classes()
- self._setup_each_mappers()
- self._setup_each_inserts()
-
- yield
-
- sa.orm.session.close_all_sessions()
- self._teardown_each_mappers()
- self._teardown_each_classes()
- self._teardown_each_tables()
-
- @classmethod
- def _teardown_once_class(cls):
- cls.classes.clear()
-
- @classmethod
- def _setup_once_classes(cls):
- if cls.run_setup_classes == "once":
- cls._with_register_classes(cls.setup_classes)
-
- @classmethod
- def _setup_once_mappers(cls):
- if cls.run_setup_mappers == "once":
- cls.mapper_registry, cls.mapper = cls._generate_registry()
- cls._with_register_classes(cls.setup_mappers)
-
- def _setup_each_mappers(self):
- if self.run_setup_mappers != "once":
- (
- self.__class__.mapper_registry,
- self.__class__.mapper,
- ) = self._generate_registry()
-
- if self.run_setup_mappers == "each":
- self._with_register_classes(self.setup_mappers)
-
- def _setup_each_classes(self):
- if self.run_setup_classes == "each":
- self._with_register_classes(self.setup_classes)
-
- @classmethod
- def _generate_registry(cls):
- decl = registry(metadata=cls._tables_metadata)
- return decl, decl.map_imperatively
-
- @classmethod
- def _with_register_classes(cls, fn):
- """Run a setup method, framing the operation with a Base class
- that will catch new subclasses to be established within
- the "classes" registry.
-
- """
- cls_registry = cls.classes
-
- class _Base:
- def __init_subclass__(cls) -> None:
- assert cls_registry is not None
- cls_registry[cls.__name__] = cls
- super().__init_subclass__()
-
- class Basic(BasicEntity, _Base):
- pass
-
- class Comparable(ComparableEntity, _Base):
- pass
-
- cls.Basic = Basic
- cls.Comparable = Comparable
- fn()
-
- def _teardown_each_mappers(self):
- # some tests create mappers in the test bodies
- # and will define setup_mappers as None -
- # clear mappers in any case
- if self.run_setup_mappers != "once":
- sa.orm.clear_mappers()
-
- def _teardown_each_classes(self):
- if self.run_setup_classes != "once":
- self.classes.clear()
-
- @classmethod
- def setup_classes(cls):
- pass
-
- @classmethod
- def setup_mappers(cls):
- pass
-
-
-class DeclarativeMappedTest(MappedTest):
- run_setup_classes = "once"
- run_setup_mappers = "once"
-
- @classmethod
- def _setup_once_tables(cls):
- pass
-
- @classmethod
- def _with_register_classes(cls, fn):
- cls_registry = cls.classes
-
- class _DeclBase(DeclarativeBase):
- __table_cls__ = schema.Table
- metadata = cls._tables_metadata
- type_annotation_map = {
- str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb", "oracle"
- )
- }
-
- def __init_subclass__(cls, **kw) -> None:
- assert cls_registry is not None
- cls_registry[cls.__name__] = cls
- super().__init_subclass__(**kw)
-
- cls.DeclarativeBasic = _DeclBase
-
- # sets up cls.Basic which is helpful for things like composite
- # classes
- super()._with_register_classes(fn)
-
- if cls._tables_metadata.tables and cls.run_create_tables:
- cls._tables_metadata.create_all(config.db)
-
-
-class ComputedReflectionFixtureTest(TablesTest):
- run_inserts = run_deletes = None
-
- __backend__ = True
- __requires__ = ("computed_columns", "table_reflection")
-
- regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
-
- def normalize(self, text):
- return self.regexp.sub("", text).lower()
-
- @classmethod
- def define_tables(cls, metadata):
- from .. import Integer
- from .. import testing
- from ..schema import Column
- from ..schema import Computed
- from ..schema import Table
-
- Table(
- "computed_default_table",
- metadata,
- Column("id", Integer, primary_key=True),
- Column("normal", Integer),
- Column("computed_col", Integer, Computed("normal + 42")),
- Column("with_default", Integer, server_default="42"),
- )
-
- t = Table(
- "computed_column_table",
- metadata,
- Column("id", Integer, primary_key=True),
- Column("normal", Integer),
- Column("computed_no_flag", Integer, Computed("normal + 42")),
- )
-
- if testing.requires.schemas.enabled:
- t2 = Table(
- "computed_column_table",
- metadata,
- Column("id", Integer, primary_key=True),
- Column("normal", Integer),
- Column("computed_no_flag", Integer, Computed("normal / 42")),
- schema=config.test_schema,
- )
-
- if testing.requires.computed_columns_virtual.enabled:
- t.append_column(
- Column(
- "computed_virtual",
- Integer,
- Computed("normal + 2", persisted=False),
- )
- )
- if testing.requires.schemas.enabled:
- t2.append_column(
- Column(
- "computed_virtual",
- Integer,
- Computed("normal / 2", persisted=False),
- )
- )
- if testing.requires.computed_columns_stored.enabled:
- t.append_column(
- Column(
- "computed_stored",
- Integer,
- Computed("normal - 42", persisted=True),
- )
- )
- if testing.requires.schemas.enabled:
- t2.append_column(
- Column(
- "computed_stored",
- Integer,
- Computed("normal * 42", persisted=True),
- )
- )
-
-
-class CacheKeyFixture:
- def _compare_equal(self, a, b, compare_values):
- a_key = a._generate_cache_key()
- b_key = b._generate_cache_key()
-
- if a_key is None:
- assert a._annotations.get("nocache")
-
- assert b_key is None
- else:
- eq_(a_key.key, b_key.key)
- eq_(hash(a_key.key), hash(b_key.key))
-
- for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
- assert a_param.compare(b_param, compare_values=compare_values)
- return a_key, b_key
-
- def _run_cache_key_fixture(self, fixture, compare_values):
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- if a == b:
- a_key, b_key = self._compare_equal(
- case_a[a], case_b[b], compare_values
- )
- if a_key is None:
- continue
- else:
- a_key = case_a[a]._generate_cache_key()
- b_key = case_b[b]._generate_cache_key()
-
- if a_key is None or b_key is None:
- if a_key is None:
- assert case_a[a]._annotations.get("nocache")
- if b_key is None:
- assert case_b[b]._annotations.get("nocache")
- continue
-
- if a_key.key == b_key.key:
- for a_param, b_param in zip(
- a_key.bindparams, b_key.bindparams
- ):
- if not a_param.compare(
- b_param, compare_values=compare_values
- ):
- break
- else:
- # this fails unconditionally since we could not
- # find bound parameter values that differed.
- # Usually we intended to get two distinct keys here
- # so the failure will be more descriptive using the
- # ne_() assertion.
- ne_(a_key.key, b_key.key)
- else:
- ne_(a_key.key, b_key.key)
-
- # ClauseElement-specific test to ensure the cache key
- # collected all the bound parameters that aren't marked
- # as "literal execute"
- if isinstance(case_a[a], ClauseElement) and isinstance(
- case_b[b], ClauseElement
- ):
- assert_a_params = []
- assert_b_params = []
-
- for elem in visitors.iterate(case_a[a]):
- if elem.__visit_name__ == "bindparam":
- assert_a_params.append(elem)
-
- for elem in visitors.iterate(case_b[b]):
- if elem.__visit_name__ == "bindparam":
- assert_b_params.append(elem)
-
- # note we're asserting the order of the params as well as
- # if there are dupes or not. ordering has to be
- # deterministic and matches what a traversal would provide.
- eq_(
- sorted(a_key.bindparams, key=lambda b: b.key),
- sorted(
- util.unique_list(assert_a_params), key=lambda b: b.key
- ),
- )
- eq_(
- sorted(b_key.bindparams, key=lambda b: b.key),
- sorted(
- util.unique_list(assert_b_params), key=lambda b: b.key
- ),
- )
-
- def _run_cache_key_equal_fixture(self, fixture, compare_values):
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- self._compare_equal(case_a[a], case_b[b], compare_values)
-
-
-def insertmanyvalues_fixture(
- connection, randomize_rows=False, warn_on_downgraded=False
-):
- dialect = connection.dialect
- orig_dialect = dialect._deliver_insertmanyvalues_batches
- orig_conn = connection._exec_insertmany_context
-
- class RandomCursor:
- __slots__ = ("cursor",)
-
- def __init__(self, cursor):
- self.cursor = cursor
-
- # only this method is called by the deliver method.
- # by not having the other methods we assert that those aren't being
- # used
-
- def fetchall(self):
- rows = self.cursor.fetchall()
- rows = list(rows)
- random.shuffle(rows)
- return rows
-
- def _deliver_insertmanyvalues_batches(
- cursor, statement, parameters, generic_setinputsizes, context
- ):
- if randomize_rows:
- cursor = RandomCursor(cursor)
- for batch in orig_dialect(
- cursor, statement, parameters, generic_setinputsizes, context
- ):
- if warn_on_downgraded and batch.is_downgraded:
- util.warn("Batches were downgraded for sorted INSERT")
-
- yield batch
-
- def _exec_insertmany_context(
- dialect,
- context,
- ):
- with mock.patch.object(
- dialect,
- "_deliver_insertmanyvalues_batches",
- new=_deliver_insertmanyvalues_batches,
- ):
- return orig_conn(dialect, context)
-
- connection._exec_insertmany_context = _exec_insertmany_context
--- /dev/null
+# testing/fixtures/__init__.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from .base import FutureEngineMixin as FutureEngineMixin
+from .base import TestBase as TestBase
+from .mypy import MypyTest as MypyTest
+from .orm import after_test as after_test
+from .orm import close_all_sessions as close_all_sessions
+from .orm import DeclarativeMappedTest as DeclarativeMappedTest
+from .orm import fixture_session as fixture_session
+from .orm import MappedTest as MappedTest
+from .orm import ORMTest as ORMTest
+from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally
+from .orm import (
+ stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
+)
+from .sql import CacheKeyFixture as CacheKeyFixture
+from .sql import (
+ ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
+)
+from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture
+from .sql import NoCache as NoCache
+from .sql import RemovesEvents as RemovesEvents
+from .sql import TablesTest as TablesTest
--- /dev/null
+# testing/fixtures/base.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from .. import assertions
+from .. import config
+from ..assertions import eq_
+from ..util import drop_all_tables_from_metadata
+from ... import Column
+from ... import func
+from ... import Integer
+from ... import select
+from ... import Table
+from ...orm import DeclarativeBase
+from ...orm import MappedAsDataclass
+from ...orm import registry
+
+
+@config.mark_base_test_class()
+class TestBase:
+ # A sequence of requirement names matching testing.requires decorators
+ __requires__ = ()
+
+ # A sequence of dialect names to exclude from the test class.
+ __unsupported_on__ = ()
+
+ # If present, test class is only runnable for the *single* specified
+ # dialect. If you need multiple, use __unsupported_on__ and invert.
+ __only_on__ = None
+
+ # A sequence of no-arg callables. If any are True, the entire testcase is
+ # skipped.
+ __skip_if__ = None
+
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
+ def assert_(self, val, msg=None):
+ assert val, msg
+
+ @config.fixture()
+ def nocache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+ @config.fixture()
+ def connection_no_trans(self):
+ eng = getattr(self, "bind", None) or config.db
+
+ with eng.connect() as conn:
+ yield conn
+
+ @config.fixture()
+ def connection(self):
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
+
+ conn = eng.connect()
+ trans = conn.begin()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is using
+ # the legacy @provide_metadata decorator still, as it will
+ # run a close all connections.
+ conn.close()
+
+ @config.fixture()
+ def close_result_when_finished(self):
+ to_close = []
+ to_consume = []
+
+ def go(result, consume=False):
+ to_close.append(result)
+ if consume:
+ to_consume.append(result)
+
+ yield go
+ for r in to_consume:
+ try:
+ r.all()
+ except:
+ pass
+ for r in to_close:
+ try:
+ r.close()
+ except:
+ pass
+
+ @config.fixture()
+ def registry(self, metadata):
+ reg = registry(
+ metadata=metadata,
+ type_annotation_map={
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb", "oracle"
+ )
+ },
+ )
+ yield reg
+ reg.dispose()
+
+ @config.fixture
+ def decl_base(self, metadata):
+ _md = metadata
+
+ class Base(DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb", "oracle"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
+ @config.fixture
+ def dc_decl_base(self, metadata):
+ _md = metadata
+
+ class Base(MappedAsDataclass, DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
+ @config.fixture()
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
+
+ @config.fixture()
+ def future_engine(self):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from .. import engines
+
+ def gen_testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ share_pool=False,
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url,
+ options=options,
+ asyncio=asyncio,
+ transfer_staticpool=transfer_staticpool,
+ share_pool=share_pool,
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
+
+ @config.fixture()
+ def async_testing_engine(self, testing_engine):
+ def go(**kw):
+ kw["asyncio"] = True
+ return testing_engine(**kw)
+
+ return go
+
+ @config.fixture()
+ def metadata(self, request):
+ """Provide bound MetaData for a single test, dropping afterwards."""
+
+ from ...sql import schema
+
+ metadata = schema.MetaData()
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+
+ @config.fixture(
+ params=[
+ (rollback, second_operation, begin_nested)
+ for rollback in (True, False)
+ for second_operation in ("none", "execute", "begin")
+ for begin_nested in (
+ True,
+ False,
+ )
+ ]
+ )
+ def trans_ctx_manager_fixture(self, request, metadata):
+ rollback, second_operation, begin_nested = request.param
+
+ t = Table("test", metadata, Column("data", Integer))
+ eng = getattr(self, "bind", None) or config.db
+
+ t.create(eng)
+
+ def run_test(subject, trans_on_subject, execute_on_subject):
+ with subject.begin() as trans:
+ if begin_nested:
+ if not config.requirements.savepoints.enabled:
+ config.skip_test("savepoints not enabled")
+ if execute_on_subject:
+ nested_trans = subject.begin_nested()
+ else:
+ nested_trans = trans.begin_nested()
+
+ with nested_trans:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ # for nested trans, we always commit/rollback on the
+ # "nested trans" object itself.
+ # only Session(future=False) will affect savepoint
+ # transaction for session.commit/rollback
+
+ if rollback:
+ nested_trans.rollback()
+ else:
+ nested_trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction "
+ "inside context "
+ "manager. Please complete the context "
+ "manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(
+ t.insert(), {"data": 12}
+ )
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ # outside the nested trans block, but still inside the
+ # transaction block, we can run SQL, and it will be
+ # committed
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 14})
+ else:
+ trans.execute(t.insert(), {"data": 14})
+
+ else:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ if trans_on_subject:
+ if rollback:
+ subject.rollback()
+ else:
+ subject.commit()
+ else:
+ if rollback:
+ trans.rollback()
+ else:
+ trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction inside "
+ "context "
+ "manager. Please complete the context manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 12})
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if hasattr(trans, "begin"):
+ trans.begin()
+ else:
+ subject.begin()
+ elif second_operation == "begin_nested":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ expected_committed = 0
+ if begin_nested:
+ # begin_nested variant, we inserted a row after the nested
+ # block
+ expected_committed += 1
+ if not rollback:
+ # not rollback variant, our row inserted in the target
+ # block itself would be committed
+ expected_committed += 1
+
+ if execute_on_subject:
+ eq_(
+ subject.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+ else:
+ with subject.connect() as conn:
+ eq_(
+ conn.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+
+ return run_test
+
+
+_connection_fixture_connection = None
+
+
+class FutureEngineMixin:
+ """alembic's suite still using this"""
--- /dev/null
+# testing/fixtures/mypy.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+import inspect
+import os
+from pathlib import Path
+import re
+import shutil
+import sys
+import tempfile
+
+from .base import TestBase
+from .. import config
+from ..assertions import eq_
+from ... import util
+
+
+@config.add_to_marker.mypy
+class MypyTest(TestBase):
+ __requires__ = ("no_sqlalchemy2_stubs",)
+
+ @config.fixture(scope="function")
+ def per_func_cachedir(self):
+ yield from self._cachedir()
+
+ @config.fixture(scope="class")
+ def cachedir(self):
+ yield from self._cachedir()
+
+ def _cachedir(self):
+ # as of mypy 0.971 i think we need to keep mypy_path empty
+ mypy_path = ""
+
+ with tempfile.TemporaryDirectory() as cachedir:
+ with open(
+ Path(cachedir) / "sqla_mypy_config.cfg", "w"
+ ) as config_file:
+ config_file.write(
+ f"""
+ [mypy]\n
+ plugins = sqlalchemy.ext.mypy.plugin\n
+ show_error_codes = True\n
+ {mypy_path}
+ disable_error_code = no-untyped-call
+
+ [mypy-sqlalchemy.*]
+ ignore_errors = True
+
+ """
+ )
+ with open(
+ Path(cachedir) / "plain_mypy_config.cfg", "w"
+ ) as config_file:
+ config_file.write(
+ f"""
+ [mypy]\n
+ show_error_codes = True\n
+ {mypy_path}
+ disable_error_code = var-annotated,no-untyped-call
+ [mypy-sqlalchemy.*]
+ ignore_errors = True
+
+ """
+ )
+ yield cachedir
+
+ @config.fixture()
+ def mypy_runner(self, cachedir):
+ from mypy import api
+
+ def run(path, use_plugin=False, use_cachedir=None):
+ if use_cachedir is None:
+ use_cachedir = cachedir
+ args = [
+ "--strict",
+ "--raise-exceptions",
+ "--cache-dir",
+ use_cachedir,
+ "--config-file",
+ os.path.join(
+ use_cachedir,
+ "sqla_mypy_config.cfg"
+ if use_plugin
+ else "plain_mypy_config.cfg",
+ ),
+ ]
+
+ # mypy as of 0.990 is more aggressively blocking messaging
+ # for paths that are in sys.path, and as pytest puts currdir,
+ # test/ etc in sys.path, just copy the source file to the
+ # tempdir we are working in so that we don't have to try to
+ # manipulate sys.path and/or guess what mypy is doing
+ filename = os.path.basename(path)
+ test_program = os.path.join(use_cachedir, filename)
+ if path != test_program:
+ shutil.copyfile(path, test_program)
+ args.append(test_program)
+
+ # I set this locally but for the suite here needs to be
+ # disabled
+ os.environ.pop("MYPY_FORCE_COLOR", None)
+
+ stdout, stderr, exitcode = api.run(args)
+ return stdout, stderr, exitcode
+
+ return run
+
+ @config.fixture
+ def mypy_typecheck_file(self, mypy_runner):
+ def run(path, use_plugin=False):
+ expected_messages = self._collect_messages(path)
+ stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
+ self._check_output(
+ path, expected_messages, stdout, stderr, exitcode
+ )
+
+ return run
+
+ @staticmethod
+ def file_combinations(dirname):
+ if os.path.isabs(dirname):
+ path = dirname
+ else:
+ caller_path = inspect.stack()[1].filename
+ path = os.path.join(os.path.dirname(caller_path), dirname)
+ files = list(Path(path).glob("**/*.py"))
+
+ for extra_dir in config.options.mypy_extra_test_paths:
+ if extra_dir and os.path.isdir(extra_dir):
+ files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
+ return files
+
+ def _collect_messages(self, path):
+ from sqlalchemy.ext.mypy.util import mypy_14
+
+ expected_messages = []
+ expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
+ py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
+ with open(path) as file_:
+ current_assert_messages = []
+ for num, line in enumerate(file_, 1):
+ m = py_ver_re.match(line)
+ if m:
+ major, _, minor = m.group(1).partition(".")
+ if sys.version_info < (int(major), int(minor)):
+ config.skip_test(
+ "Requires python >= %s" % (m.group(1))
+ )
+ continue
+
+ m = expected_re.match(line)
+ if m:
+ is_mypy = bool(m.group(1))
+ is_re = bool(m.group(2))
+ is_type = bool(m.group(3))
+
+ expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
+ if is_type:
+ if not is_re:
+ # the goal here is that we can cut-and-paste
+ # from vscode -> pylance into the
+ # EXPECTED_TYPE: line, then the test suite will
+ # validate that line against what mypy produces
+ expected_msg = re.sub(
+ r"([\[\]])",
+ lambda m: rf"\{m.group(0)}",
+ expected_msg,
+ )
+
+ # note making sure preceding text matches
+ # with a dot, so that an expect for "Select"
+ # does not match "TypedSelect"
+ expected_msg = re.sub(
+ r"([\w_]+)",
+ lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
+ expected_msg,
+ )
+
+ expected_msg = re.sub(
+ "List", "builtins.list", expected_msg
+ )
+
+ expected_msg = re.sub(
+ r"\b(int|str|float|bool)\b",
+ lambda m: rf"builtins.{m.group(0)}\*?",
+ expected_msg,
+ )
+ # expected_msg = re.sub(
+ # r"(Sequence|Tuple|List|Union)",
+ # lambda m: fr"typing.{m.group(0)}\*?",
+ # expected_msg,
+ # )
+
+ is_mypy = is_re = True
+ expected_msg = f'Revealed type is "{expected_msg}"'
+
+ if mypy_14 and util.py39:
+ # use_lowercase_names, py39 and above
+ # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501
+
+ # skip first character which could be capitalized
+ # "List item x not found" type of message
+ expected_msg = expected_msg[0] + re.sub(
+ r"\b(List|Tuple|Dict|Set)\b"
+ if is_type
+ else r"\b(List|Tuple|Dict|Set|Type)\b",
+ lambda m: m.group(1).lower(),
+ expected_msg[1:],
+ )
+
+ if mypy_14 and util.py310:
+ # use_or_syntax, py310 and above
+ # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
+ expected_msg = re.sub(
+ r"Optional\[(.*?)\]",
+ lambda m: f"{m.group(1)} | None",
+ expected_msg,
+ )
+ current_assert_messages.append(
+ (is_mypy, is_re, expected_msg.strip())
+ )
+ elif current_assert_messages:
+ expected_messages.extend(
+ (num, is_mypy, is_re, expected_msg)
+ for (
+ is_mypy,
+ is_re,
+ expected_msg,
+ ) in current_assert_messages
+ )
+ current_assert_messages[:] = []
+
+ return expected_messages
+
+ def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
+ not_located = []
+ filename = os.path.basename(path)
+ if expected_messages:
+ # mypy 0.990 changed how return codes work, so don't assume a
+ # 1 or a 0 return code here, could be either depending on if
+ # errors were generated or not
+
+ output = []
+
+ raw_lines = stdout.split("\n")
+ while raw_lines:
+ e = raw_lines.pop(0)
+ if re.match(r".+\.py:\d+: error: .*", e):
+ output.append(("error", e))
+ elif re.match(
+ r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
+ ):
+ while raw_lines:
+ ol = raw_lines.pop(0)
+ if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
+ break
+ elif re.match(
+ r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
+ ):
+ pass
+ elif re.match(r".+\.py:\d+: note: .*", e):
+ output.append(("note", e))
+
+ for num, is_mypy, is_re, msg in expected_messages:
+ msg = msg.replace("'", '"')
+ prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
+ for idx, (typ, errmsg) in enumerate(output):
+ if is_re:
+ if re.match(
+ rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
+ errmsg,
+ ):
+ break
+ elif (
+ f"{filename}:{num}: {typ}: {prefix}{msg}"
+ in errmsg.replace("'", '"')
+ ):
+ break
+ else:
+ not_located.append(msg)
+ continue
+ del output[idx]
+
+ if not_located:
+ missing = "\n".join(not_located)
+ print("Couldn't locate expected messages:", missing, sep="\n")
+ if output:
+ extra = "\n".join(msg for _, msg in output)
+ print("Remaining messages:", extra, sep="\n")
+ assert False, "expected messages not found, see stdout"
+
+ if output:
+ print(f"{len(output)} messages from mypy were not consumed:")
+ print("\n".join(msg for _, msg in output))
+ assert False, "errors and/or notes remain, see stdout"
+
+ else:
+ if exitcode != 0:
+ print(stdout, stderr, sep="\n")
+
+ eq_(exitcode, 0, msg=stdout)
--- /dev/null
+# testing/fixtures/orm.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+
+import sqlalchemy as sa
+from .base import TestBase
+from .sql import TablesTest
+from .. import assertions
+from .. import config
+from .. import schema
+from ..entities import BasicEntity
+from ..entities import ComparableEntity
+from ..util import adict
+from ... import orm
+from ...orm import DeclarativeBase
+from ...orm import events as orm_events
+from ...orm import registry
+
+
+class ORMTest(TestBase):
+ @config.fixture
+ def fixture_session(self):
+ return fixture_session()
+
+
+class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults):
+ # 'once', 'each', None
+ run_setup_classes = "once"
+
+ # 'once', 'each', None
+ run_setup_mappers = "each"
+
+ classes: Any = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ if cls.classes is None:
+ cls.classes = adict()
+
+ cls._setup_once_tables()
+ cls._setup_once_classes()
+ cls._setup_once_mappers()
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_class()
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_classes()
+ self._setup_each_mappers()
+ self._setup_each_inserts()
+
+ yield
+
+ orm.session.close_all_sessions()
+ self._teardown_each_mappers()
+ self._teardown_each_classes()
+ self._teardown_each_tables()
+
+ @classmethod
+ def _teardown_once_class(cls):
+ cls.classes.clear()
+
+ @classmethod
+ def _setup_once_classes(cls):
+ if cls.run_setup_classes == "once":
+ cls._with_register_classes(cls.setup_classes)
+
+ @classmethod
+ def _setup_once_mappers(cls):
+ if cls.run_setup_mappers == "once":
+ cls.mapper_registry, cls.mapper = cls._generate_registry()
+ cls._with_register_classes(cls.setup_mappers)
+
+ def _setup_each_mappers(self):
+ if self.run_setup_mappers != "once":
+ (
+ self.__class__.mapper_registry,
+ self.__class__.mapper,
+ ) = self._generate_registry()
+
+ if self.run_setup_mappers == "each":
+ self._with_register_classes(self.setup_mappers)
+
+ def _setup_each_classes(self):
+ if self.run_setup_classes == "each":
+ self._with_register_classes(self.setup_classes)
+
+ @classmethod
+ def _generate_registry(cls):
+ decl = registry(metadata=cls._tables_metadata)
+ return decl, decl.map_imperatively
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ """Run a setup method, framing the operation with a Base class
+ that will catch new subclasses to be established within
+ the "classes" registry.
+
+ """
+ cls_registry = cls.classes
+
+ class _Base:
+ def __init_subclass__(cls) -> None:
+ assert cls_registry is not None
+ cls_registry[cls.__name__] = cls
+ super().__init_subclass__()
+
+ class Basic(BasicEntity, _Base):
+ pass
+
+ class Comparable(ComparableEntity, _Base):
+ pass
+
+ cls.Basic = Basic
+ cls.Comparable = Comparable
+ fn()
+
+ def _teardown_each_mappers(self):
+ # some tests create mappers in the test bodies
+ # and will define setup_mappers as None -
+ # clear mappers in any case
+ if self.run_setup_mappers != "once":
+ orm.clear_mappers()
+
+ def _teardown_each_classes(self):
+ if self.run_setup_classes != "once":
+ self.classes.clear()
+
+ @classmethod
+ def setup_classes(cls):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ pass
+
+
+class DeclarativeMappedTest(MappedTest):
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
+
+ @classmethod
+ def _setup_once_tables(cls):
+ pass
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ cls_registry = cls.classes
+
+ class _DeclBase(DeclarativeBase):
+ __table_cls__ = schema.Table
+ metadata = cls._tables_metadata
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb", "oracle"
+ )
+ }
+
+ def __init_subclass__(cls, **kw) -> None:
+ assert cls_registry is not None
+ cls_registry[cls.__name__] = cls
+ super().__init_subclass__(**kw)
+
+ cls.DeclarativeBasic = _DeclBase
+
+ # sets up cls.Basic which is helpful for things like composite
+ # classes
+ super()._with_register_classes(fn)
+
+ if cls._tables_metadata.tables and cls.run_create_tables:
+ cls._tables_metadata.create_all(config.db)
+
+
+class RemoveORMEventsGlobally:
+ @config.fixture(autouse=True)
+ def _remove_listeners(self):
+ yield
+ orm_events.MapperEvents._clear()
+ orm_events.InstanceEvents._clear()
+ orm_events.SessionEvents._clear()
+ orm_events.InstrumentationEvents._clear()
+ orm_events.QueryEvents._clear()
+
+
+_fixture_sessions = set()
+
+
+def fixture_session(**kw):
+ kw.setdefault("autoflush", True)
+ kw.setdefault("expire_on_commit", True)
+
+ bind = kw.pop("bind", config.db)
+
+ sess = orm.Session(bind, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
+
+def close_all_sessions():
+ # will close all still-referenced sessions
+ orm.close_all_sessions()
+ _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+ close_all_sessions()
+ orm.clear_mappers()
+
+
+def after_test():
+ if _fixture_sessions:
+ close_all_sessions()
--- /dev/null
+# testing/fixtures/sql.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from __future__ import annotations
+
+import itertools
+import random
+import re
+import sys
+
+import sqlalchemy as sa
+from .base import TestBase
+from .. import config
+from .. import mock
+from ..assertions import eq_
+from ..assertions import ne_
+from ..util import adict
+from ..util import drop_all_tables_from_metadata
+from ... import event
+from ... import util
+from ...schema import sort_tables_and_constraints
+from ...sql import visitors
+from ...sql.elements import ClauseElement
+
+
+class TablesTest(TestBase):
+ # 'once', None
+ run_setup_bind = "once"
+
+ # 'once', 'each', None
+ run_define_tables = "once"
+
+ # 'once', 'each', None
+ run_create_tables = "once"
+
+ # 'once', 'each', None
+ run_inserts = "each"
+
+ # 'each', None
+ run_deletes = "each"
+
+ # 'once', None
+ run_dispose_bind = None
+
+ bind = None
+ _tables_metadata = None
+ tables = None
+ other = None
+ sequences = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ cls._setup_once_tables()
+
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
+ @classmethod
+ def _init_class(cls):
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
+
+ cls.other = adict()
+ cls.tables = adict()
+ cls.sequences = adict()
+
+ cls.bind = cls.setup_bind()
+ cls._tables_metadata = sa.MetaData()
+
+ @classmethod
+ def _setup_once_inserts(cls):
+ if cls.run_inserts == "once":
+ cls._load_fixtures()
+ with cls.bind.begin() as conn:
+ cls.insert_data(conn)
+
+ @classmethod
+ def _setup_once_tables(cls):
+ if cls.run_define_tables == "once":
+ cls.define_tables(cls._tables_metadata)
+ if cls.run_create_tables == "once":
+ cls._tables_metadata.create_all(cls.bind)
+ cls.tables.update(cls._tables_metadata.tables)
+ cls.sequences.update(cls._tables_metadata._sequences)
+
+ def _setup_each_tables(self):
+ if self.run_define_tables == "each":
+ self.define_tables(self._tables_metadata)
+ if self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+ self.tables.update(self._tables_metadata.tables)
+ self.sequences.update(self._tables_metadata._sequences)
+ elif self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+
+ def _setup_each_inserts(self):
+ if self.run_inserts == "each":
+ self._load_fixtures()
+ with self.bind.begin() as conn:
+ self.insert_data(conn)
+
+ def _teardown_each_tables(self):
+ if self.run_define_tables == "each":
+ self.tables.clear()
+ if self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+ self._tables_metadata.clear()
+ elif self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+ savepoints = getattr(config.requirements, "savepoints", False)
+ if savepoints:
+ savepoints = savepoints.enabled
+
+ # no need to run deletes if tables are recreated on setup
+ if (
+ self.run_define_tables != "each"
+ and self.run_create_tables != "each"
+ and self.run_deletes == "each"
+ ):
+ with self.bind.begin() as conn:
+ for table in reversed(
+ [
+ t
+ for (t, fks) in sort_tables_and_constraints(
+ self._tables_metadata.tables.values()
+ )
+ if t is not None
+ ]
+ ):
+ try:
+ if savepoints:
+ with conn.begin_nested():
+ conn.execute(table.delete())
+ else:
+ conn.execute(table.delete())
+ except sa.exc.DBAPIError as ex:
+ print(
+ ("Error emptying table %s: %r" % (table, ex)),
+ file=sys.stderr,
+ )
+
+ @classmethod
+ def _teardown_once_metadata_bind(cls):
+ if cls.run_create_tables:
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+ if cls.run_dispose_bind == "once":
+ cls.dispose_bind(cls.bind)
+
+ cls._tables_metadata.bind = None
+
+ if cls.run_setup_bind is not None:
+ cls.bind = None
+
+ @classmethod
+ def setup_bind(cls):
+ return config.db
+
+ @classmethod
+ def dispose_bind(cls, bind):
+ if hasattr(bind, "dispose"):
+ bind.dispose()
+ elif hasattr(bind, "close"):
+ bind.close()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ pass
+
+ @classmethod
+ def fixtures(cls):
+ return {}
+
+ @classmethod
+ def insert_data(cls, connection):
+ pass
+
+ def sql_count_(self, count, fn):
+ self.assert_sql_count(self.bind, fn, count)
+
+ def sql_eq_(self, callable_, statements):
+ self.assert_sql(self.bind, callable_, statements)
+
+ @classmethod
+ def _load_fixtures(cls):
+ """Insert rows as represented by the fixtures() method."""
+ headers, rows = {}, {}
+ for table, data in cls.fixtures().items():
+ if len(data) < 2:
+ continue
+ if isinstance(table, str):
+ table = cls.tables[table]
+ headers[table] = data[0]
+ rows[table] = data[1:]
+ for table, fks in sort_tables_and_constraints(
+ cls._tables_metadata.tables.values()
+ ):
+ if table is None:
+ continue
+ if table not in headers:
+ continue
+ with cls.bind.begin() as conn:
+ conn.execute(
+ table.insert(),
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
+
+class NoCache:
+ @config.fixture(autouse=True, scope="function")
+ def _disable_cache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+
+class RemovesEvents:
+ @util.memoized_property
+ def _event_fns(self):
+ return set()
+
+ def event_listen(self, target, name, fn, **kw):
+ self._event_fns.add((target, name, fn))
+ event.listen(target, name, fn, **kw)
+
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
+ for key in self._event_fns:
+ event.remove(*key)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("computed_columns", "table_reflection")
+
+ regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
+ def normalize(self, text):
+ return self.regexp.sub("", text).lower()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from ... import Integer
+ from ... import testing
+ from ...schema import Column
+ from ...schema import Computed
+ from ...schema import Table
+
+ Table(
+ "computed_default_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_col", Integer, Computed("normal + 42")),
+ Column("with_default", Integer, server_default="42"),
+ )
+
+ t = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal + 42")),
+ )
+
+ if testing.requires.schemas.enabled:
+ t2 = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal / 42")),
+ schema=config.test_schema,
+ )
+
+ if testing.requires.computed_columns_virtual.enabled:
+ t.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal + 2", persisted=False),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal / 2", persisted=False),
+ )
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ t.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal - 42", persisted=True),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal * 42", persisted=True),
+ )
+ )
+
+
+class CacheKeyFixture:
+ def _compare_equal(self, a, b, compare_values):
+ a_key = a._generate_cache_key()
+ b_key = b._generate_cache_key()
+
+ if a_key is None:
+ assert a._annotations.get("nocache")
+
+ assert b_key is None
+ else:
+ eq_(a_key.key, b_key.key)
+ eq_(hash(a_key.key), hash(b_key.key))
+
+ for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+ assert a_param.compare(b_param, compare_values=compare_values)
+ return a_key, b_key
+
+ def _run_cache_key_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ a_key, b_key = self._compare_equal(
+ case_a[a], case_b[b], compare_values
+ )
+ if a_key is None:
+ continue
+ else:
+ a_key = case_a[a]._generate_cache_key()
+ b_key = case_b[b]._generate_cache_key()
+
+ if a_key is None or b_key is None:
+ if a_key is None:
+ assert case_a[a]._annotations.get("nocache")
+ if b_key is None:
+ assert case_b[b]._annotations.get("nocache")
+ continue
+
+ if a_key.key == b_key.key:
+ for a_param, b_param in zip(
+ a_key.bindparams, b_key.bindparams
+ ):
+ if not a_param.compare(
+ b_param, compare_values=compare_values
+ ):
+ break
+ else:
+ # this fails unconditionally since we could not
+ # find bound parameter values that differed.
+ # Usually we intended to get two distinct keys here
+ # so the failure will be more descriptive using the
+ # ne_() assertion.
+ ne_(a_key.key, b_key.key)
+ else:
+ ne_(a_key.key, b_key.key)
+
+ # ClauseElement-specific test to ensure the cache key
+ # collected all the bound parameters that aren't marked
+ # as "literal execute"
+ if isinstance(case_a[a], ClauseElement) and isinstance(
+ case_b[b], ClauseElement
+ ):
+ assert_a_params = []
+ assert_b_params = []
+
+ for elem in visitors.iterate(case_a[a]):
+ if elem.__visit_name__ == "bindparam":
+ assert_a_params.append(elem)
+
+ for elem in visitors.iterate(case_b[b]):
+ if elem.__visit_name__ == "bindparam":
+ assert_b_params.append(elem)
+
+ # note we're asserting the order of the params as well as
+ # if there are dupes or not. ordering has to be
+ # deterministic and matches what a traversal would provide.
+ eq_(
+ sorted(a_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_a_params), key=lambda b: b.key
+ ),
+ )
+ eq_(
+ sorted(b_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_b_params), key=lambda b: b.key
+ ),
+ )
+
+ def _run_cache_key_equal_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ self._compare_equal(case_a[a], case_b[b], compare_values)
+
+
+def insertmanyvalues_fixture(
+ connection, randomize_rows=False, warn_on_downgraded=False
+):
+ dialect = connection.dialect
+ orig_dialect = dialect._deliver_insertmanyvalues_batches
+ orig_conn = connection._exec_insertmany_context
+
+ class RandomCursor:
+ __slots__ = ("cursor",)
+
+ def __init__(self, cursor):
+ self.cursor = cursor
+
+ # only this method is called by the deliver method.
+ # by not having the other methods we assert that those aren't being
+ # used
+
+ def fetchall(self):
+ rows = self.cursor.fetchall()
+ rows = list(rows)
+ random.shuffle(rows)
+ return rows
+
+ def _deliver_insertmanyvalues_batches(
+ cursor, statement, parameters, generic_setinputsizes, context
+ ):
+ if randomize_rows:
+ cursor = RandomCursor(cursor)
+ for batch in orig_dialect(
+ cursor, statement, parameters, generic_setinputsizes, context
+ ):
+ if warn_on_downgraded and batch.is_downgraded:
+ util.warn("Batches were downgraded for sorted INSERT")
+
+ yield batch
+
+ def _exec_insertmany_context(
+ dialect,
+ context,
+ ):
+ with mock.patch.object(
+ dialect,
+ "_deliver_insertmanyvalues_batches",
+ new=_deliver_insertmanyvalues_batches,
+ ):
+ return orig_conn(dialect, context)
+
+ connection._exec_insertmany_context = _exec_insertmany_context
from __future__ import annotations
-from . import fixtures
+from .entities import ComparableEntity
from ..schema import Column
from ..types import String
-class User(fixtures.ComparableEntity):
+class User(ComparableEntity):
pass
-class Order(fixtures.ComparableEntity):
+class Order(ComparableEntity):
pass
-class Dingaling(fixtures.ComparableEntity):
+class Dingaling(ComparableEntity):
pass
pass
-class Address(fixtures.ComparableEntity):
+class Address(ComparableEntity):
pass
# TODO: these are kind of arbitrary....
-class Child1(fixtures.ComparableEntity):
+class Child1(ComparableEntity):
pass
-class Child2(fixtures.ComparableEntity):
+class Child2(ComparableEntity):
pass
-class Parent(fixtures.ComparableEntity):
+class Parent(ComparableEntity):
pass
email_address = Column(String)
-class AddressWMixin(Mixin, fixtures.ComparableEntity):
+class AddressWMixin(Mixin, ComparableEntity):
pass
from __future__ import annotations
import abc
+from argparse import Namespace
import configparser
import logging
import os
logging = None
include_tags = set()
exclude_tags = set()
-options = None
+options: Namespace = None # type: ignore
def setup_options(make_option):
import os
import re
import sys
+from typing import TYPE_CHECKING
import uuid
import pytest
try:
# installed by bootstrap.py
- import sqla_plugin_base as plugin_base
+ if not TYPE_CHECKING:
+ import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
# we have to hardcode some of that cleanup ahead of time.
# close ORM sessions
- fixtures._close_all_sessions()
+ fixtures.close_all_sessions()
# integrate with the "connection" fixture as there are many
# tests where it is used along with provide_metadata
- if fixtures._connection_fixture_connection:
+ cfc = fixtures.base._connection_fixture_connection
+ if cfc:
# TODO: this warning can be used to find all the places
# this is used with connection fixture
# warn("mixing legacy provide metadata with connection fixture")
- drop_all_tables_from_metadata(
- metadata, fixtures._connection_fixture_connection
- )
+ drop_all_tables_from_metadata(metadata, cfc)
# as the provide_metadata fixture is often used with "testing.db",
# when we do the drop we have to commit the transaction so that
# the DB is actually updated as the CREATE would have been
# committed
- fixtures._connection_fixture_connection.get_transaction().commit()
+ cfc.get_transaction().commit()
else:
drop_all_tables_from_metadata(metadata, config.db)
self.metadata = prev_meta
per-file-ignores =
**/__init__.py:F401
test/*:FA100
- test/ext/mypy/plain_files/*:F821,E501,FA100
+ test/typing/plain_files/*:F821,E501,FA100
test/ext/mypy/plugin_files/*:F821,E501,FA100
lib/sqlalchemy/events.py:F401
lib/sqlalchemy/schema.py:F401
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import pickleable
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from ..orm import _fixtures
-class A(fixtures.ComparableEntity):
+class A(ComparableEntity):
pass
-class B(fixtures.ComparableEntity):
+class B(ComparableEntity):
pass
@profile_memory()
def go():
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
class B(A):
@profile_memory()
def go():
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import expect_deprecated
from sqlalchemy.testing.assertions import is_false
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.provision import normalize_sequence
from .test_engine_py3k import AsyncFixture as _AsyncFixture
from ...orm import _fixtures
def decl_base(self, metadata):
_md = metadata
- class Base(fixtures.ComparableEntity, AsyncAttrs, DeclarativeBase):
+ class Base(ComparableEntity, AsyncAttrs, DeclarativeBase):
metadata = _md
type_annotation_map = {
str: String().with_variant(
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally
from sqlalchemy.testing.schema import Column
"punion",
)
- class Employee(Base, fixtures.ComparableEntity):
+ class Employee(Base, ComparableEntity):
__table__ = punion
__mapper_args__ = {"polymorphic_on": punion.c.type}
def test_concrete_inline_non_polymorphic(self):
"""test the example from the declarative docs."""
- class Employee(Base, fixtures.ComparableEntity):
+ class Employee(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
self._roundtrip(Employee, Manager, Engineer, Boss, polymorphic=False)
def test_abstract_concrete_base_didnt_configure(self):
- class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(AbstractConcreteBase, Base, ComparableEntity):
strict_attrs = True
assert_raises_message(
)
def test_abstract_concrete_extension(self):
- class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(AbstractConcreteBase, Base, ComparableEntity):
name = Column(String(50))
class Manager(Employee):
def test_abstract_concrete_extension_descriptor_refresh(
self, use_strict_attrs
):
- class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(AbstractConcreteBase, Base, ComparableEntity):
strict_attrs = use_strict_attrs
@declared_attr
eq_(e1.name, "d")
def test_concrete_extension(self):
- class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(ConcreteBase, Base, ComparableEntity):
__tablename__ = "employee"
employee_id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
self._roundtrip(Employee, Manager, Engineer, Boss)
def test_concrete_extension_warn_for_overlap(self):
- class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(ConcreteBase, Base, ComparableEntity):
__tablename__ = "employee"
employee_id = Column(
configure_mappers()
def test_concrete_extension_warn_concrete_disc_resolves_overlap(self):
- class Employee(ConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(ConcreteBase, Base, ComparableEntity):
_concrete_discriminator_name = "_type"
__tablename__ = "employee"
)
def test_abs_concrete_extension_warn_for_overlap(self):
- class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(AbstractConcreteBase, Base, ComparableEntity):
name = Column(String(50))
__mapper_args__ = {
"polymorphic_identity": "employee",
def test_abs_concrete_extension_warn_concrete_disc_resolves_overlap(
self, use_strict_attrs
):
- class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(AbstractConcreteBase, Base, ComparableEntity):
strict_attrs = use_strict_attrs
_concrete_discriminator_name = "_type"
assert PolyTest.__mapper__.polymorphic_on is Test.__table__.c.type
def test_ok_to_override_type_from_abstract(self):
- class Employee(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class Employee(AbstractConcreteBase, Base, ComparableEntity):
name = Column(String(50))
class Manager(Employee):
__dialect__ = "default"
def test_classreg_setup(self):
- class A(Base, fixtures.ComparableEntity):
+ class A(Base, ComparableEntity):
__tablename__ = "a"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
"BC", primaryjoin="BC.a_id == A.id", collection_class=set
)
- class BC(AbstractConcreteBase, Base, fixtures.ComparableEntity):
+ class BC(AbstractConcreteBase, Base, ComparableEntity):
a_id = Column(Integer, ForeignKey("a.id"))
class B(BC):
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
)
def test_pk_fk(self):
- class B(DeferredReflection, fixtures.ComparableEntity, Base):
+ class B(DeferredReflection, ComparableEntity, Base):
__tablename__ = "b"
a = relationship("A")
- class A(DeferredReflection, fixtures.ComparableEntity, Base):
+ class A(DeferredReflection, ComparableEntity, Base):
__tablename__ = "a"
DeferredReflection.prepare(testing.db)
eq_(a1.user, User(name="u1"))
def test_exception_prepare_not_called(self):
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
addresses = relationship("Address", backref="user")
- class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Address(DeferredReflection, ComparableEntity, Base):
__tablename__ = "addresses"
assert_raises_message(
@testing.variation("bind", ["engine", "connection", "raise_"])
def test_basic_deferred(self, bind):
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
addresses = relationship("Address", backref="user")
- class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Address(DeferredReflection, ComparableEntity, Base):
__tablename__ = "addresses"
if bind.engine:
class OtherDefBase(DeferredReflection, Base):
__abstract__ = True
- class User(fixtures.ComparableEntity, DefBase):
+ class User(ComparableEntity, DefBase):
__tablename__ = "users"
addresses = relationship("Address", backref="user")
- class Address(fixtures.ComparableEntity, DefBase):
+ class Address(ComparableEntity, DefBase):
__tablename__ = "addresses"
class Fake(OtherDefBase):
self._roundtrip()
def test_redefine_fk_double(self):
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
addresses = relationship("Address", backref="user")
- class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Address(DeferredReflection, ComparableEntity, Base):
__tablename__ = "addresses"
user_id = Column(Integer, ForeignKey("users.id"))
"""test that __mapper_args__ is not called until *after*
table reflection"""
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
@declared_attr
@testing.requires.predictable_gc
def test_cls_not_strong_ref(self):
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
- class Address(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Address(DeferredReflection, ComparableEntity, Base):
__tablename__ = "addresses"
eq_(len(_DeferredMapperConfig._configs), 2)
)
def test_string_resolution(self):
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
items = relationship("Item", secondary="user_items")
- class Item(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Item(DeferredReflection, ComparableEntity, Base):
__tablename__ = "items"
DeferredReflection.prepare(testing.db)
self._roundtrip()
def test_table_resolution(self):
- class User(DeferredReflection, fixtures.ComparableEntity, Base):
+ class User(DeferredReflection, ComparableEntity, Base):
__tablename__ = "users"
items = relationship(
"Item", secondary=Table("user_items", Base.metadata)
)
- class Item(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Item(DeferredReflection, ComparableEntity, Base):
__tablename__ = "items"
DeferredReflection.prepare(testing.db)
)
def test_basic(self, decl_base):
- class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+ class Foo(DeferredReflection, ComparableEntity, decl_base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_add_subclass_column(self, decl_base):
- class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+ class Foo(DeferredReflection, ComparableEntity, decl_base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_add_subclass_mapped_column(self, decl_base):
- class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+ class Foo(DeferredReflection, ComparableEntity, decl_base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_subclass_mapped_column_no_existing(self, decl_base):
- class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+ class Foo(DeferredReflection, ComparableEntity, decl_base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
bar_data: Mapped[str] = mapped_column(use_existing_column=True)
def test_add_pk_column(self, decl_base):
- class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+ class Foo(DeferredReflection, ComparableEntity, decl_base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_add_pk_mapped_column(self, decl_base):
- class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+ class Foo(DeferredReflection, ComparableEntity, decl_base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
)
def test_basic(self):
- class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Foo(DeferredReflection, ComparableEntity, Base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_add_subclass_column(self):
- class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Foo(DeferredReflection, ComparableEntity, Base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_add_pk_column(self):
- class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Foo(DeferredReflection, ComparableEntity, Base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
self._roundtrip()
def test_add_fk_pk_column(self):
- class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+ class Foo(DeferredReflection, ComparableEntity, Base):
__tablename__ = "foo"
__mapper_args__ = {
"polymorphic_on": "type",
+++ /dev/null
-from sqlalchemy import CheckConstraint
-from sqlalchemy import Column
-from sqlalchemy import DateTime
-from sqlalchemy import ForeignKey
-from sqlalchemy import Index
-from sqlalchemy import Integer
-from sqlalchemy import MetaData
-from sqlalchemy import PrimaryKeyConstraint
-from sqlalchemy import String
-from sqlalchemy import Table
-
-
-m = MetaData()
-
-
-t1 = Table(
- "t1",
- m,
- Column("id", Integer, primary_key=True),
- Column("data", String),
- Column("data2", String(50)),
- Column("timestamp", DateTime()),
- Index(None, "data2"),
-)
-
-t2 = Table(
- "t2",
- m,
- Column("t1id", ForeignKey("t1.id")),
- Column("q", Integer, CheckConstraint("q > 5")),
-)
-
-t3 = Table(
- "t3",
- m,
- Column("x", Integer),
- Column("y", Integer),
- Column("t1id", ForeignKey(t1.c.id)),
- PrimaryKeyConstraint("x", "y"),
-)
-
-# cols w/ no name or type, used by declarative
-c1: Column[int] = Column(ForeignKey(t3.c.x))
import os
-import re
import shutil
-import sys
-import tempfile
-from typing import Any
-from typing import cast
-from typing import List
-from typing import Tuple
from sqlalchemy import testing
-from sqlalchemy import util
-from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-def _file_combinations(dirname):
- path = os.path.join(os.path.dirname(__file__), dirname)
- files = []
- for f in os.listdir(path):
- if f.endswith(".py"):
- files.append(os.path.join(os.path.dirname(__file__), dirname, f))
-
- for extra_dir in testing.config.options.mypy_extra_test_paths:
- if extra_dir and os.path.isdir(extra_dir):
- for f in os.listdir(os.path.join(extra_dir, dirname)):
- if f.endswith(".py"):
- files.append(os.path.join(extra_dir, dirname, f))
- return files
-
-
def _incremental_dirs():
path = os.path.join(os.path.dirname(__file__), "incremental")
files = []
return files
-@testing.add_to_marker.mypy
-class MypyPluginTest(fixtures.TestBase):
- __tags__ = ("mypy",)
- __requires__ = ("no_sqlalchemy2_stubs",)
-
- @testing.fixture(scope="function")
- def per_func_cachedir(self):
- yield from self._cachedir()
-
- @testing.fixture(scope="class")
- def cachedir(self):
- yield from self._cachedir()
-
- def _cachedir(self):
- # as of mypy 0.971 i think we need to keep mypy_path empty
- mypy_path = ""
-
- with tempfile.TemporaryDirectory() as cachedir:
- with open(
- os.path.join(cachedir, "sqla_mypy_config.cfg"), "w"
- ) as config_file:
- config_file.write(
- f"""
- [mypy]\n
- plugins = sqlalchemy.ext.mypy.plugin\n
- show_error_codes = True\n
- {mypy_path}
- disable_error_code = no-untyped-call
-
- [mypy-sqlalchemy.*]
- ignore_errors = True
-
- """
- )
- with open(
- os.path.join(cachedir, "plain_mypy_config.cfg"), "w"
- ) as config_file:
- config_file.write(
- f"""
- [mypy]\n
- show_error_codes = True\n
- {mypy_path}
- disable_error_code = var-annotated,no-untyped-call
- [mypy-sqlalchemy.*]
- ignore_errors = True
-
- """
- )
- yield cachedir
-
- @testing.fixture()
- def mypy_runner(self, cachedir):
- from mypy import api
-
- def run(path, use_plugin=True, incremental=False):
- args = [
- "--strict",
- "--raise-exceptions",
- "--cache-dir",
- cachedir,
- "--config-file",
- os.path.join(
- cachedir,
- "sqla_mypy_config.cfg"
- if use_plugin
- else "plain_mypy_config.cfg",
- ),
- ]
-
- # mypy as of 0.990 is more aggressively blocking messaging
- # for paths that are in sys.path, and as pytest puts currdir,
- # test/ etc in sys.path, just copy the source file to the
- # tempdir we are working in so that we don't have to try to
- # manipulate sys.path and/or guess what mypy is doing
- filename = os.path.basename(path)
- test_program = os.path.join(cachedir, filename)
- shutil.copyfile(path, test_program)
- args.append(test_program)
-
- # I set this locally but for the suite here needs to be
- # disabled
- os.environ.pop("MYPY_FORCE_COLOR", None)
-
- result = api.run(args)
- return result
-
- return run
-
+class MypyPluginTest(fixtures.MypyTest):
@testing.combinations(
- *[
- (pathname, testing.exclusions.closed())
- for pathname in _incremental_dirs()
- ],
+ *[(pathname) for pathname in _incremental_dirs()],
argnames="pathname",
)
@testing.requires.patch_library
result = mypy_runner(
dest,
use_plugin=True,
- incremental=True,
+ use_cachedir=cachedir,
)
eq_(
result[2],
@testing.combinations(
*(
- cast(
- List[Tuple[Any, ...]],
- [
- ("w_plugin", os.path.basename(path), path, True)
- for path in _file_combinations("plugin_files")
- ],
- )
- + cast(
- List[Tuple[Any, ...]],
- [
- ("plain", os.path.basename(path), path, False)
- for path in _file_combinations("plain_files")
- ],
- )
+ (os.path.basename(path), path, True)
+ for path in fixtures.MypyTest.file_combinations("plugin_files")
),
- argnames="filename,path,use_plugin",
- id_="isaa",
+ argnames="path",
+ id_="ia",
)
- def test_files(self, mypy_runner, filename, path, use_plugin):
- expected_messages = []
- expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
- py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
-
- from sqlalchemy.ext.mypy.util import mypy_14
-
- with open(path) as file_:
- current_assert_messages = []
- for num, line in enumerate(file_, 1):
- m = py_ver_re.match(line)
- if m:
- major, _, minor = m.group(1).partition(".")
- if sys.version_info < (int(major), int(minor)):
- config.skip_test(
- "Requires python >= %s" % (m.group(1))
- )
- continue
- if line.startswith("# NOPLUGINS"):
- use_plugin = False
- continue
-
- m = expected_re.match(line)
- if m:
- is_mypy = bool(m.group(1))
- is_re = bool(m.group(2))
- is_type = bool(m.group(3))
-
- expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
-
- if is_type:
- if not is_re:
- # the goal here is that we can cut-and-paste
- # from vscode -> pylance into the
- # EXPECTED_TYPE: line, then the test suite will
- # validate that line against what mypy produces
- expected_msg = re.sub(
- r"([\[\]])",
- lambda m: rf"\{m.group(0)}",
- expected_msg,
- )
-
- # note making sure preceding text matches
- # with a dot, so that an expect for "Select"
- # does not match "TypedSelect"
- expected_msg = re.sub(
- r"([\w_]+)",
- lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
- expected_msg,
- )
-
- expected_msg = re.sub(
- "List", "builtins.list", expected_msg
- )
-
- expected_msg = re.sub(
- r"\b(int|str|float|bool)\b",
- lambda m: rf"builtins.{m.group(0)}\*?",
- expected_msg,
- )
- # expected_msg = re.sub(
- # r"(Sequence|Tuple|List|Union)",
- # lambda m: fr"typing.{m.group(0)}\*?",
- # expected_msg,
- # )
-
- is_mypy = is_re = True
- expected_msg = f'Revealed type is "{expected_msg}"'
-
- if mypy_14 and util.py39:
- # use_lowercase_names, py39 and above
- # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501
-
- # skip first character which could be capitalized
- # "List item x not found" type of message
- expected_msg = expected_msg[0] + re.sub(
- r"\b(List|Tuple|Dict|Set)\b"
- if is_type
- else r"\b(List|Tuple|Dict|Set|Type)\b",
- lambda m: m.group(1).lower(),
- expected_msg[1:],
- )
-
- if mypy_14 and util.py310:
- # use_or_syntax, py310 and above
- # https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
- expected_msg = re.sub(
- r"Optional\[(.*?)\]",
- lambda m: f"{m.group(1)} | None",
- expected_msg,
- )
- current_assert_messages.append(
- (is_mypy, is_re, expected_msg.strip())
- )
- elif current_assert_messages:
- expected_messages.extend(
- (num, is_mypy, is_re, expected_msg)
- for (
- is_mypy,
- is_re,
- expected_msg,
- ) in current_assert_messages
- )
- current_assert_messages[:] = []
-
- result = mypy_runner(path, use_plugin=use_plugin)
-
- not_located = []
-
- if expected_messages:
- # mypy 0.990 changed how return codes work, so don't assume a
- # 1 or a 0 return code here, could be either depending on if
- # errors were generated or not
-
- output = []
-
- raw_lines = result[0].split("\n")
- while raw_lines:
- e = raw_lines.pop(0)
- if re.match(r".+\.py:\d+: error: .*", e):
- output.append(("error", e))
- elif re.match(
- r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
- ):
- while raw_lines:
- ol = raw_lines.pop(0)
- if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
- break
- elif re.match(
- r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
- ):
- pass
- elif re.match(r".+\.py:\d+: note: .*", e):
- output.append(("note", e))
-
- for num, is_mypy, is_re, msg in expected_messages:
- msg = msg.replace("'", '"')
- prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
- for idx, (typ, errmsg) in enumerate(output):
- if is_re:
- if re.match(
- rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}", # noqa: E501
- errmsg,
- ):
- break
- elif (
- f"{filename}:{num}: {typ}: {prefix}{msg}"
- in errmsg.replace("'", '"')
- ):
- break
- else:
- not_located.append(msg)
- continue
- del output[idx]
-
- if not_located:
- print(f"Couldn't locate expected messages: {not_located}")
- print("\n".join(msg for _, msg in output))
- assert False, "expected messages not found, see stdout"
-
- if output:
- print(f"{len(output)} messages from mypy were not consumed:")
- print("\n".join(msg for _, msg in output))
- assert False, "errors and/or notes remain, see stdout"
-
- else:
- if result[2] != 0:
- print(result[0])
-
- eq_(result[2], 0, msg=result)
+ def test_plugin_files(self, mypy_typecheck_file, path):
+ mypy_typecheck_file(path, use_plugin=True)
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.entities import ComparableEntity # noqa
from sqlalchemy.testing.entities import ComparableMixin # noqa
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
creator=lambda point: PointData(point=point),
)
- class PointData(fixtures.ComparableEntity, cls.DeclarativeBasic):
+ class PointData(ComparableEntity, cls.DeclarativeBasic):
__tablename__ = "point"
id = Column(
from sqlalchemy.testing import is_
from sqlalchemy.testing import ne_
from sqlalchemy.testing import not_in
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.schema import Column
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class Array(fixtures.ComparableEntity, Base):
+ class Array(ComparableEntity, Base):
__tablename__ = "array"
id = Column(
expr = super().expr(model)
return expr.astext.cast(self.cast_type)
- class Json(fixtures.ComparableEntity, Base):
+ class Json(ComparableEntity, Base):
__tablename__ = "json"
id = Column(
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import BasicEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.types import VARCHAR
-class Foo(fixtures.BasicEntity):
+class Foo(BasicEntity):
pass
pass
-class Foo2(fixtures.BasicEntity):
+class Foo2(BasicEntity):
pass
return self.id == other.id
-class FooWNoHash(fixtures.BasicEntity):
+class FooWNoHash(BasicEntity):
__hash__ = None
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
# return iter([-1, 0, 1, 2])
-class User(fixtures.ComparableEntity):
+class User(ComparableEntity):
pass
-class Address(fixtures.ComparableEntity):
+class Address(ComparableEntity):
pass
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
xyzzy = "magic"
# _as_declarative() inspects obj.__class__.__bases__
- class User(BrokenParent, fixtures.ComparableEntity):
+ class User(BrokenParent, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
assert Base().foobar() == "foobar"
def test_as_declarative(self, metadata):
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
addresses = relationship("Address", backref="user")
- class Address(fixtures.ComparableEntity):
+ class Address(ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_map_declaratively(self, metadata):
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
addresses = relationship("Address", backref="user")
- class Address(fixtures.ComparableEntity):
+ class Address(ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
testing.config.skip_test("current base has no metaclass")
def test_basic(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
name = Column("name", String(50))
addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
)
def test_unicode_string_resolve(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
name = Column("name", String(50))
addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
assert User.addresses.property.mapper.class_ is Address
def test_unicode_string_resolve_backref(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
)
name = Column("name", String(50))
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
assert User.__mapper__.registry._new_mappers is False
def test_string_dependency_resolution(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
),
)
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
),
)
- class Foo(Base, fixtures.ComparableEntity):
+ class Foo(Base, ComparableEntity):
__tablename__ = "foo"
id = Column(Integer, primary_key=True)
rel = relationship("User", primaryjoin="User.addresses==Foo.id")
)
def test_string_dependency_resolution_synonym(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
sess.expunge_all()
eq_(sess.query(User).filter(User.name == "ed").one(), User(name="ed"))
- class Foo(Base, fixtures.ComparableEntity):
+ class Foo(Base, ComparableEntity):
__tablename__ = "foo"
id = Column(Integer, primary_key=True)
_user_id = Column(Integer)
)
def test_string_dependency_resolution_no_table(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
name = Column(String(50))
- class Bar(Base, fixtures.ComparableEntity):
+ class Bar(Base, ComparableEntity):
__tablename__ = "bar"
id = Column(Integer, primary_key=True)
rel = relationship("User", primaryjoin="User.id==Bar.__table__.id")
)
def test_string_w_pj_annotations(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
name = Column(String(50))
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_string_dependency_resolution_no_magic(self):
"""test that full tinkery expressions work as written"""
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
addresses = relationship(
primaryjoin="User.id==Address.user_id.prop.columns[0]",
)
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
)
def test_string_dependency_resolution_module_qualified(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
addresses = relationship(
% (__name__, __name__),
)
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
)
def test_string_dependency_resolution_in_backref(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
backref="user",
)
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
email = Column(String(50))
)
def test_string_dependency_resolution_tables(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
backref="users",
)
- class Prop(Base, fixtures.ComparableEntity):
+ class Prop(Base, ComparableEntity):
__tablename__ = "props"
id = Column(Integer, primary_key=True)
name = Column(String(50))
def test_string_dependency_resolution_table_over_class(self):
# test for second half of #5774
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
backref="users",
)
- class Prop(Base, fixtures.ComparableEntity):
+ class Prop(Base, ComparableEntity):
__tablename__ = "props"
id = Column(Integer, primary_key=True)
name = Column(String(50))
def test_string_dependency_resolution_class_over_table(self):
# test for second half of #5774
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
)
def test_uncompiled_attributes_in_relationship(self):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
email = Column(String(50))
user_id = Column(Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_add_prop_auto(
self, require_metaclass, assert_user_address_mapping, _column
):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
User.name = _column("name", String(50))
User.addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = _column(Integer, primary_key=True)
@testing.combinations(Column, mapped_column, argnames="_column")
def test_add_prop_manual(self, assert_user_address_mapping, _column):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = _column("id", Integer, primary_key=True)
User, "addresses", relationship("Address", backref="user")
)
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = _column(Integer, primary_key=True)
A(brap=B())
def test_eager_order_by(self):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
email = Column("email", String(50))
user_id = Column("user_id", Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_order_by_multi(self):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
email = Column("email", String(50))
user_id = Column("user_id", Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
"Ignoring declarative-like tuple value of " "attribute 'name'"
):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
name = (Column("name", String(50)),)
is_(inspect(Employee).local_table, Person.__table__)
def test_expression(self, require_metaclass):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_useless_declared_attr(self):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
email = Column("email", String(50))
user_id = Column("user_id", Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
return Column(Integer)
def test_column(self, require_metaclass):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
eq_(Foo.d.impl.active_history, False)
def test_column_properties(self):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
email = Column(String(50))
user_id = Column(Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_column_properties_2(self):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
eq_(set(Address.__table__.c.keys()), {"id", "email", "user_id"})
def test_deferred(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
self.assert_sql_count(testing.db, go, 1)
def test_composite_inline(self):
- class AddressComposite(fixtures.ComparableEntity):
+ class AddressComposite(ComparableEntity):
def __init__(self, street, state):
self.street = street
self.state = state
def __composite_values__(self):
return [self.street, self.state]
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "user"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
def test_composite_separate(self):
- class AddressComposite(fixtures.ComparableEntity):
+ class AddressComposite(ComparableEntity):
def __init__(self, street, state):
self.street = street
self.state = state
def __composite_values__(self):
return [self.street, self.state]
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "user"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
def test_synonym_inline(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def __eq__(self, other):
return self.__clause_element__() == other + " FOO"
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
eq_(sess.query(User).filter(User.name == "someuser").one(), u1)
def test_synonym_added(self, require_metaclass):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_reentrant_compile_via_foreignkey(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_relationship_reference(self, require_metaclass):
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
email = Column("email", String(50))
user_id = Column("user_id", Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
eq_(sess.execute(t1.select()).fetchall(), [("someid", "somedata")])
def test_synonym_for(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
configure_mappers()
def test_joined(self):
- class Company(Base, fixtures.ComparableEntity):
+ class Company(Base, ComparableEntity):
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
employees = relationship("Person")
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
self.assert_sql_count(testing.db, go, 1)
def test_add_subcol_after_the_fact(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_add_parentcol_after_the_fact(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_add_sub_parentcol_after_the_fact(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
)
def test_subclass_mixin(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
"""test single inheritance where all the columns are on the base
class."""
- class Company(Base, fixtures.ComparableEntity):
+ class Company(Base, ComparableEntity):
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
employees = relationship("Person")
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
"""
- class Company(Base, fixtures.ComparableEntity):
+ class Company(Base, ComparableEntity):
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
employees = relationship("Person")
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_single_constraint_on_sub(self):
"""test the somewhat unusual case of [ticket:3341]"""
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
is_(Manager.id.property.columns[0], Person.__table__.c.id)
def test_joined_from_single(self):
- class Company(Base, fixtures.ComparableEntity):
+ class Company(Base, ComparableEntity):
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
name = Column("name", String(50))
employees = relationship("Person")
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
def test_single_from_joined_colsonsub(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
is_(B.__mapper__.polymorphic_on, A.__table__.c.discriminator)
def test_add_deferred(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
"""
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
primary_language_id = Column(Integer, ForeignKey("languages.id"))
primary_language = relationship("Language")
- class Language(Base, fixtures.ComparableEntity):
+ class Language(Base, ComparableEntity):
__tablename__ = "languages"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
def test_single_three_levels(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column(Integer, primary_key=True)
name = Column(String(50))
assert_raises(sa.exc.ArgumentError, go)
def test_single_no_special_cols(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
assert_raises_message(sa.exc.ArgumentError, "place primary key", go)
def test_single_no_table_args(self):
- class Person(Base, fixtures.ComparableEntity):
+ class Person(Base, ComparableEntity):
__tablename__ = "people"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
+from sqlalchemy.testing.entities import ComparableMixin
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
)
def test_basic(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
__autoload_with__ = testing.db
addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
__autoload_with__ = testing.db
eq_(a1.user, User(name="u1"))
def test_rekey_wbase(self):
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
__autoload_with__ = testing.db
nom = Column("name", String(50), key="nom")
addresses = relationship("Address", backref="user")
- class Address(Base, fixtures.ComparableEntity):
+ class Address(Base, ComparableEntity):
__tablename__ = "addresses"
__autoload_with__ = testing.db
def test_rekey_wdecorator(self):
@registry.mapped
- class User(fixtures.ComparableMixin):
+ class User(ComparableMixin):
__tablename__ = "users"
__autoload_with__ = testing.db
nom = Column("name", String(50), key="nom")
addresses = relationship("Address", backref="user")
@registry.mapped
- class Address(fixtures.ComparableMixin):
+ class Address(ComparableMixin):
__tablename__ = "addresses"
__autoload_with__ = testing.db
assert_raises(TypeError, User, name="u3")
def test_supplied_fk(self):
- class IMHandle(Base, fixtures.ComparableEntity):
+ class IMHandle(Base, ComparableEntity):
__tablename__ = "imhandles"
__autoload_with__ = testing.db
user_id = Column("user_id", Integer, ForeignKey("users.id"))
- class User(Base, fixtures.ComparableEntity):
+ class User(Base, ComparableEntity):
__tablename__ = "users"
__autoload_with__ = testing.db
handles = relationship("IMHandle", backref="user")
style: testing.Variation,
sort_by_parameter_order,
):
- class A(fixtures.ComparableEntity, decl_base):
+ class A(ComparableEntity, decl_base):
__tablename__ = "a"
id: Mapped[int] = mapped_column(Identity(), primary_key=True)
data: Mapped[str]
def setup_classes(cls):
decl_base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, decl_base):
+ class A(ComparableEntity, decl_base):
__tablename__ = "a"
id: Mapped[int] = mapped_column(Identity(), primary_key=True)
type: Mapped[str]
def setup_classes(cls):
decl_base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, decl_base):
+ class A(ComparableEntity, decl_base):
__tablename__ = "a"
id: Mapped[int] = mapped_column(Identity(), primary_key=True)
type: Mapped[str]
def setup_classes(cls):
decl_base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, decl_base):
+ class A(ComparableEntity, decl_base):
__tablename__ = "a"
id: Mapped[int] = mapped_column(Identity(), primary_key=True)
type: Mapped[str]
def setup_classes(cls):
decl_base = cls.DeclarativeBasic
- class User(fixtures.ComparableEntity, decl_base):
+ class User(ComparableEntity, decl_base):
__tablename__ = "users"
id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
username: Mapped[str]
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import config
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-class Company(fixtures.ComparableEntity):
+class Company(ComparableEntity):
pass
-class Person(fixtures.ComparableEntity):
+class Person(ComparableEntity):
pass
pass
-class Machine(fixtures.ComparableEntity):
+class Machine(ComparableEntity):
pass
-class MachineType(fixtures.ComparableEntity):
+class MachineType(ComparableEntity):
pass
-class Paperwork(fixtures.ComparableEntity):
+class Paperwork(ComparableEntity):
pass
-class Page(fixtures.ComparableEntity):
+class Page(ComparableEntity):
pass
items["__mapper_args__"][mapper_opt] = value[mapper_opt]
if is_base:
- klass = type(key, (fixtures.ComparableEntity, base), items)
+ klass = type(key, (ComparableEntity, base), items)
else:
klass = type(key, (base,), items)
from sqlalchemy import testing
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@testing.combinations(("union",), ("none",))
def test_abc_poly_roundtrip(self, fetchtype):
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
class B(A):
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import ComparableEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
)
def test_selfref_onjoined(self):
- class Taggable(fixtures.ComparableEntity):
+ class Taggable(ComparableEntity):
pass
class User(Taggable):
"""test that Query uses the full set of mapper._eager_loaders
when generating SQL"""
- class Person(fixtures.ComparableEntity):
+ class Person(ComparableEntity):
pass
class Employee(Person):
def __init__(self, name="bob"):
self.name = name
- class Tag(fixtures.ComparableEntity):
+ class Tag(ComparableEntity):
def __init__(self, label):
self.label = label
from sqlalchemy.testing.assertsql import Conditional
from sqlalchemy.testing.assertsql import Or
from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(
)
def test_polymorphic_synonym(self):
- class T1(fixtures.ComparableEntity):
+ class T1(ComparableEntity):
def info(self):
return "THE INFO IS:" + self._info
)
def test_cascade(self):
- class T1(fixtures.BasicEntity):
+ class T1(BasicEntity):
pass
- class T2(fixtures.BasicEntity):
+ class T2(BasicEntity):
pass
class T3(T2):
pass
- class T4(fixtures.BasicEntity):
+ class T4(BasicEntity):
pass
self.mapper_registry.map_imperatively(
)
# test [ticket:1186]
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class Sub(Base):
def test_adapt_stringency(self):
b_table, a_table = self.tables.b_table, self.tables.a_table
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
class B(A):
self.tables.stuff,
)
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class Sub(Base):
def test_delete(self):
subtable, base = self.tables.subtable, self.tables.base
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class Sub(Base):
def test_no_optimize_on_map_to_join(self):
base, sub = self.tables.base, self.tables.sub
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
- class JoinBase(fixtures.ComparableEntity):
+ class JoinBase(ComparableEntity):
pass
class SubJoinBase(JoinBase):
base, sub = self.tables.base, self.tables.sub
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
class Sub(Base):
base, sub = self.tables.base, self.tables.sub
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
class Sub(Base):
def test_column_expression(self):
base, sub = self.tables.base, self.tables.sub
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
class Sub(Base):
def test_column_expression_joined(self):
base, sub = self.tables.base, self.tables.sub
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
class Sub(Base):
def test_composite_column_joined(self):
base, with_comp = self.tables.base, self.tables.with_comp
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class WithComp(Base):
expected_eager_defaults and testing.db.dialect.insert_returning
)
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class Sub(Base):
def test_dont_generate_on_none(self):
base, sub = self.tables.base, self.tables.sub
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class Sub(Base):
self.tables.subsub,
)
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class Sub(Base):
)
def test_orphan_message(self):
- class Base(fixtures.BasicEntity):
+ class Base(BasicEntity):
pass
class SubClass(Base):
pass
- class Parent(fixtures.BasicEntity):
+ class Parent(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class Parent(fixtures.ComparableEntity, Base):
+ class Parent(ComparableEntity, Base):
__tablename__ = "parent"
id = Column(Integer, primary_key=True)
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
parent_id = Column(ForeignKey("parent.id"))
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class AJoined(fixtures.ComparableEntity, Base):
+ class AJoined(ComparableEntity, Base):
__tablename__ = "ajoined"
id = Column(Integer, primary_key=True)
type = Column(String(10), nullable=False)
id = Column(ForeignKey("ajoined.id"), primary_key=True)
__mapper_args__ = {"polymorphic_identity": "subb"}
- class ASingle(fixtures.ComparableEntity, Base):
+ class ASingle(ComparableEntity, Base):
__tablename__ = "asingle"
id = Column(Integer, primary_key=True)
type = Column(String(10), nullable=False)
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "table_a"
order_id: Mapped[str] = mapped_column(String(50), primary_key=True)
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class Parent(fixtures.ComparableEntity, Base):
+ class Parent(ComparableEntity, Base):
__tablename__ = "parent"
id = Column(Integer, primary_key=True)
- class Child(fixtures.ComparableEntity, Base):
+ class Child(ComparableEntity, Base):
__tablename__ = "child"
id = Column(Integer, primary_key=True)
parent_id = Column(Integer, ForeignKey("parent.id"))
"polymorphic_load": "selectin",
}
- class Other(fixtures.ComparableEntity, Base):
+ class Other(ComparableEntity, Base):
__tablename__ = "other"
id = Column(Integer, primary_key=True)
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
-class Person(fixtures.ComparableEntity):
+class Person(ComparableEntity):
pass
pass
-class Company(fixtures.ComparableEntity):
+class Company(ComparableEntity):
pass
from sqlalchemy.testing.schema import Table
-class Company(fixtures.ComparableEntity):
+class Company(ComparableEntity):
pass
-class Person(fixtures.ComparableEntity):
+class Person(ComparableEntity):
pass
pass
-class Machine(fixtures.ComparableEntity):
+class Machine(ComparableEntity):
pass
-class Paperwork(fixtures.ComparableEntity):
+class Paperwork(ComparableEntity):
pass
from sqlalchemy.orm import Session
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
connection.execute(foo.insert(), dict(a="i am bar", b="bar"))
connection.execute(foo.insert(), dict(a="also bar", b="bar"))
- class Foo(fixtures.ComparableEntity):
+ class Foo(ComparableEntity):
pass
class Bar(Foo):
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
)
def test_single_on_joined(self):
- class Person(fixtures.ComparableEntity):
+ class Person(ComparableEntity):
pass
class Employee(Person):
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.fixtures import ComparableEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing import is_true
from sqlalchemy.testing import not_in
from sqlalchemy.testing.assertions import assert_warns
+from sqlalchemy.testing.entities import BasicEntity
from sqlalchemy.testing.util import all_partial_orderings
from sqlalchemy.testing.util import gc_collect
def test_lazyhistory(self):
"""tests that history functions work with lazy-loading attributes"""
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
instrumentation.register_class(Foo)
class HistoryTest(fixtures.TestBase):
def _fixture(self, uselist, useobject, active_history, **kw):
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
instrumentation.register_class(Foo)
return Foo
def _two_obj_fixture(self, uselist, active_history=False):
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
def __bool__(self):
assert False
def test_dict_collections(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
instrumentation.register_class(Foo)
def test_object_collections_mutate(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
instrumentation.register_class(Foo)
def test_collections_via_backref(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
instrumentation.register_class(Foo)
def test_lazy_backref_collections(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
lazy_load = []
def test_collections_via_lazyload(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
lazy_load = []
def test_scalar_via_lazyload(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
lazy_load = None
def test_scalar_via_lazyload_with_active(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
lazy_load = None
def test_scalar_object_via_lazyload(self):
# TODO: break into individual tests
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
lazy_load = None
class CollectionKeyTest(fixtures.ORMTest):
@testing.fixture
def dict_collection(self):
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
def __init__(self, name):
self.name = name
@testing.fixture
def list_collection(self):
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
- class Bar(fixtures.BasicEntity):
+ class Bar(BasicEntity):
pass
instrumentation.register_class(Foo)
from sqlalchemy.testing import in_
from sqlalchemy.testing import not_in
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
self.tables.accounts,
)
- class Customer(fixtures.ComparableEntity):
+ class Customer(ComparableEntity):
pass
- class Account(fixtures.ComparableEntity):
+ class Account(ComparableEntity):
pass
- class SalesRep(fixtures.ComparableEntity):
+ class SalesRep(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
self.tables.addresses,
)
- class Address(fixtures.ComparableEntity):
+ class Address(ComparableEntity):
pass
- class Home(fixtures.ComparableEntity):
+ class Home(ComparableEntity):
pass
- class Business(fixtures.ComparableEntity):
+ class Business(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Address, addresses)
self.tables.addresses,
)
- class Address(fixtures.ComparableEntity):
+ class Address(ComparableEntity):
pass
- class Home(fixtures.ComparableEntity):
+ class Home(ComparableEntity):
pass
- class Business(fixtures.ComparableEntity):
+ class Business(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Address, addresses)
def test_basic(self):
table_b, table_a = self.tables.table_b, self.tables.table_a
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_o2m_m2o(self):
base, noninh_child = self.tables.base, self.tables.noninh_child
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
- class Child(fixtures.ComparableEntity):
+ class Child(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
self.tables.parent,
)
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
class Parent(Base):
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.entities import BasicEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def test_exclude(self):
dt = self.tables.dt
- class Foo(fixtures.BasicEntity):
+ class Foo(BasicEntity):
pass
self.mapper_registry.map_imperatively(
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
x = Column(Integer)
bs = relationship("B", order_by="B.id")
- class A_default(fixtures.ComparableEntity, Base):
+ class A_default(ComparableEntity, Base):
__tablename__ = "a_default"
id = Column(Integer, primary_key=True)
x = Column(Integer)
my_expr = query_expression(default_expr=literal(15))
- class B(fixtures.ComparableEntity, Base):
+ class B(ComparableEntity, Base):
__tablename__ = "b"
id = Column(Integer, primary_key=True)
a_id = Column(ForeignKey("a.id"))
b_expr = query_expression()
- class C(fixtures.ComparableEntity, Base):
+ class C(ComparableEntity, Base):
__tablename__ = "c"
id = Column(Integer, primary_key=True)
x = Column(Integer)
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
x = Column(Integer)
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import CacheKeyFixture
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.fixtures import RemoveORMEventsGlobally
assert_col = []
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
def _get_name(self):
assert_col.append(("get", self._name))
return self._name
from sqlalchemy.testing import is_not
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def test_ordering(self):
a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b)
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_basic(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_lazy_fallback_doesnt_affect_eager(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_with_deferred(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_options(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_no_depth(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_basic(self):
widget, widget_rel = self.tables.widget, self.tables.widget_rel
- class Widget(fixtures.ComparableEntity):
+ class Widget(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
self.tables.users_table,
)
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@property
def prop_score(self):
return sum([tag.prop_score for tag in self.tags])
- class Tag(fixtures.ComparableEntity):
+ class Tag(ComparableEntity):
@property
def prop_score(self):
return self.score1 * self.score2
def _do_test(self, labeled, ondate, aliasstuff):
stuff, users = self.tables.stuff, self.tables.users
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
pass
- class Stuff(fixtures.ComparableEntity):
+ class Stuff(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Stuff, stuff)
self.tables.sub1,
)
- class Base(fixtures.ComparableEntity):
+ class Base(ComparableEntity):
pass
- class Sub1(fixtures.ComparableEntity):
+ class Sub1(ComparableEntity):
pass
- class Sub2(fixtures.ComparableEntity):
+ class Sub2(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
Column("data", MyHashType()),
)
- class Category(fixtures.ComparableEntity):
+ class Category(ComparableEntity):
pass
- class Article(fixtures.ComparableEntity):
+ class Article(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Category, category)
def test_correlated_lazyload(self):
stuff, user_t = self.tables.stuff, self.tables.user_t
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
pass
- class Stuff(fixtures.ComparableEntity):
+ class Stuff(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Stuff, stuff)
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
-from sqlalchemy.testing.fixtures import ComparableEntity
-from sqlalchemy.testing.fixtures import ComparableMixin
+from sqlalchemy.testing.entities import ComparableEntity
+from sqlalchemy.testing.entities import ComparableMixin
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
assert_col = []
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
def _get_name(self):
assert_col.append(("get", self._name))
return self._name
from sqlalchemy.testing import in_
from sqlalchemy.testing import not_in
from sqlalchemy.testing.assertsql import CountStatements
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@classmethod
def setup_classes(cls):
- class Rock(cls.Basic, fixtures.ComparableEntity):
+ class Rock(cls.Basic, ComparableEntity):
pass
- class Bug(cls.Basic, fixtures.ComparableEntity):
+ class Bug(cls.Basic, ComparableEntity):
pass
def _setup_delete_orphan_o2o(self):
@classmethod
def setup_classes(cls):
- class Employee(cls.Basic, fixtures.ComparableEntity):
+ class Employee(cls.Basic, ComparableEntity):
pass
class Manager(Employee):
from sqlalchemy.testing.assertions import expect_warnings
from sqlalchemy.testing.assertions import is_not_none
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
s = users.outerjoin(addresses)
- class UserThing(fixtures.ComparableEntity):
+ class UserThing(ComparableEntity):
pass
registry.map_imperatively(
from sqlalchemy.testing import is_
from sqlalchemy.testing.assertsql import assert_engine
from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def _fixture(self, uselist=False):
a, b = self.tables.a, self.tables.b
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
self.mapper_registry.map_imperatively(
)
tableC.create(connection)
- class C(fixtures.BasicEntity):
+ class C(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def test_basic(self):
items = self.tables.items
- class Container(fixtures.BasicEntity):
+ class Container(BasicEntity):
pass
- class LineItem(fixtures.BasicEntity):
+ class LineItem(BasicEntity):
pass
container_select = (
def test_basic(self):
tag_foo, tags = self.tables.tag_foo, self.tables.tags
- class Tag(fixtures.ComparableEntity):
+ class Tag(ComparableEntity):
pass
- class TagInstance(fixtures.ComparableEntity):
+ class TagInstance(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_o2m_oncascade(self):
a, c, b = (self.tables.a, self.tables.c, self.tables.b)
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
- class C(fixtures.BasicEntity):
+ class C(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def test_o2m_onflush(self):
a, c, b = (self.tables.a, self.tables.c, self.tables.b)
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
- class C(fixtures.BasicEntity):
+ class C(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def test_o2m_nopoly_onflush(self):
a, c, b = (self.tables.a, self.tables.c, self.tables.b)
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
class C(B):
def test_m2o_nopoly_onflush(self):
a, b, d = (self.tables.a, self.tables.b, self.tables.d)
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
class B(A):
pass
- class D(fixtures.BasicEntity):
+ class D(BasicEntity):
pass
self.mapper_registry.map_imperatively(A, a)
def test_m2o_oncascade(self):
a, b, d = (self.tables.a, self.tables.b, self.tables.d)
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
- class D(fixtures.BasicEntity):
+ class D(BasicEntity):
pass
self.mapper_registry.map_imperatively(A, a)
t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1)
- class T1(fixtures.BasicEntity):
+ class T1(BasicEntity):
pass
- class T2(fixtures.BasicEntity):
+ class T2(BasicEntity):
pass
self.mapper_registry.map_imperatively(T2, t2)
)
def test_join_on_custom_op_legacy_is_comparison(self):
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
self.mapper_registry.map_imperatively(
)
def test_join_on_custom_bool_op(self):
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
self.mapper_registry.map_imperatively(
return s
def test_o2m_viewonly_oneside(self):
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
assert b1 not in sess.dirty
def test_m2o_viewonly_oneside(self):
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
assert b1 not in sess.dirty
def test_o2m_viewonly_only(self):
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
self._assert_fk(a1, b1, False)
def test_m2o_viewonly_only(self):
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(A, self.tables.t1)
def test_viewonly(self):
t1t2, t2, t1 = (self.tables.t1t2, self.tables.t2, self.tables.t1)
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1)
- class C1(fixtures.BasicEntity):
+ class C1(BasicEntity):
pass
- class C2(fixtures.BasicEntity):
+ class C2(BasicEntity):
pass
- class C3(fixtures.BasicEntity):
+ class C3(BasicEntity):
pass
self.mapper_registry.map_imperatively(
@testing.combinations(True, False, None, argnames="B_a_sync")
@testing.combinations(True, False, argnames="B_a_view")
def test_case(self, B_a_view, B_a_sync, A_bs_view, A_bs_sync):
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
case = self.cases[(B_a_view, B_a_sync, A_bs_view, A_bs_sync)]
t2, t3, t1 = (self.tables.t2, self.tables.t3, self.tables.t1)
- class C1(fixtures.BasicEntity):
+ class C1(BasicEntity):
pass
- class C2(fixtures.BasicEntity):
+ class C2(BasicEntity):
pass
- class C3(fixtures.BasicEntity):
+ class C3(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def test_viewonly_join(self):
bars, foos = self.tables.bars, self.tables.foos
- class Foo(fixtures.ComparableEntity):
+ class Foo(ComparableEntity):
pass
- class Bar(fixtures.ComparableEntity):
+ class Bar(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_relationship_on_or(self):
bars, foos = self.tables.bars, self.tables.foos
- class Foo(fixtures.ComparableEntity):
+ class Foo(ComparableEntity):
pass
- class Bar(fixtures.ComparableEntity):
+ class Bar(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_relationship_on_or(self):
bars, foos = self.tables.bars, self.tables.foos
- class Foo(fixtures.ComparableEntity):
+ class Foo(ComparableEntity):
pass
- class Bar(fixtures.ComparableEntity):
+ class Bar(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class Network(fixtures.ComparableEntity, Base):
+ class Network(ComparableEntity, Base):
__tablename__ = "network"
id = Column(
viewonly=True,
)
- class Address(fixtures.ComparableEntity, Base):
+ class Address(ComparableEntity, Base):
__tablename__ = "address"
ip_addr = Column(Integer, primary_key=True)
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
class CustomQuery(query.Query):
pass
- class SomeObject(fixtures.ComparableEntity):
+ class SomeObject(ComparableEntity):
query = Session.query_property()
- class SomeOtherObject(fixtures.ComparableEntity):
+ class SomeOtherObject(ComparableEntity):
query = Session.query_property()
custom_query = Session.query_property(query_cls=CustomQuery)
from sqlalchemy.testing.assertsql import AllOf
from sqlalchemy.testing.assertsql import assert_engine
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.fixtures import ComparableEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
def test_ordering(self):
a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b)
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id1 = Column(Integer, primary_key=True)
id2 = Column(Integer, primary_key=True)
bs = relationship("B", order_by="B.id", back_populates="a")
- class B(fixtures.ComparableEntity, Base):
+ class B(ComparableEntity, Base):
__tablename__ = "b"
id = Column(Integer, primary_key=True)
a_id1 = Column()
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
bs = relationship("B", order_by="B.id", back_populates="a")
- class B(fixtures.ComparableEntity, Base):
+ class B(ComparableEntity, Base):
__tablename__ = "b"
id = Column(Integer, primary_key=True)
a_id = Column(ForeignKey("a.id"))
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class Foo(fixtures.ComparableEntity, Base):
+ class Foo(ComparableEntity, Base):
__tablename__ = "foo"
id = Column(Integer, primary_key=True)
type = Column(String(50))
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
b_id = Column(Integer)
b = relationship("B", primaryjoin="foreign(A.b_id) == B.id")
q = Column(Integer)
- class B(fixtures.ComparableEntity, Base):
+ class B(ComparableEntity, Base):
__tablename__ = "b"
id = Column(Integer, primary_key=True)
x = Column(Integer)
def setup_classes(cls):
Base = cls.DeclarativeBasic
- class A(fixtures.ComparableEntity, Base):
+ class A(ComparableEntity, Base):
__tablename__ = "a"
id = Column(Integer, primary_key=True)
b_id = Column(ForeignKey("b.id"))
b_no_omit_join = relationship("B", omit_join=False, overlaps="b")
q = Column(Integer)
- class B(fixtures.ComparableEntity, Base):
+ class B(ComparableEntity, Base):
__tablename__ = "b"
id = Column(Integer, primary_key=True)
x = Column(Integer)
def test_ordering(self):
a, m2m, b = (self.tables.a, self.tables.m2m, self.tables.b)
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_basic(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_lazy_fallback_doesnt_affect_eager(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_with_deferred(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
def test_options(self):
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
nodes = self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
def append(self, node):
self.children.append(node)
from sqlalchemy.testing.assertsql import AllOf
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
def test_mapping(self):
t2, t1 = self.tables.t2, self.tables.t1
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
- class B(fixtures.ComparableEntity):
+ class B(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_inheritance_mapping(self):
t2, t1 = self.tables.t2, self.tables.t1
- class A(fixtures.ComparableEntity):
+ class A(ComparableEntity):
pass
class B(A):
def test_naming(self):
book = self.tables.book
- class Book(fixtures.ComparableEntity):
+ class Book(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Book, book)
def test_synonym(self):
users = self.tables.users
- class SUser(fixtures.BasicEntity):
+ class SUser(BasicEntity):
def _get_name(self):
return "User:" + self.name
self.classes.Item,
)
- class IKAssociation(fixtures.ComparableEntity):
+ class IKAssociation(ComparableEntity):
pass
self.mapper_registry.map_imperatively(Keyword, keywords)
t1_t = self.tables.t1_t
# use the regular mapper
- class T(fixtures.ComparableEntity):
+ class T(ComparableEntity):
pass
self.mapper_registry.map_imperatively(T, t1_t)
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
from sqlalchemy.testing.assertsql import RegexSQL
+from sqlalchemy.testing.entities import BasicEntity
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.provision import normalize_sequence
from sqlalchemy.testing.schema import Column
def test_flush_size(self):
foobars, nodes = self.tables.foobars, self.tables.nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
pass
- class FooBar(fixtures.ComparableEntity):
+ class FooBar(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def test_many_to_many_one(self):
nodes, node_to_nodes = self.tables.nodes, self.tables.node_to_nodes
- class Node(fixtures.ComparableEntity):
+ class Node(ComparableEntity):
pass
self.mapper_registry.map_imperatively(
def _fixture(self):
parent, child = self.tables.parent, self.tables.child
- class Parent(fixtures.BasicEntity):
+ class Parent(BasicEntity):
pass
- class Child(fixtures.BasicEntity):
+ class Child(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def _fixture(self):
a, b, c = self.tables.a, self.tables.b, self.tables.c
- class A(fixtures.BasicEntity):
+ class A(BasicEntity):
pass
- class B(fixtures.BasicEntity):
+ class B(BasicEntity):
pass
- class C(fixtures.BasicEntity):
+ class C(BasicEntity):
pass
self.mapper_registry.map_imperatively(
def _fixture(self, confirm_deleted_rows=True):
parent, child = self.tables.parent, self.tables.child
- class Parent(fixtures.BasicEntity):
+ class Parent(BasicEntity):
pass
- class Child(fixtures.BasicEntity):
+ class Child(BasicEntity):
pass
self.mapper_registry.map_imperatively(
t = self.tables.t
- class T(fixtures.ComparableEntity):
+ class T(ComparableEntity):
pass
mp = self.mapper_registry.map_imperatively(T, t)
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
-from sqlalchemy.testing import fixtures
from sqlalchemy.testing import ne_
+from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from test.orm import _fixtures
users = self.tables.users
canary = Mock()
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@validates("name")
def validate_name(self, key, name):
canary(key, name)
canary = Mock()
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@validates("addresses")
def validate_address(self, key, ad):
canary(key, ad)
self.classes.Address,
)
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@validates("name")
def validate_name(self, key, name):
ne_(name, "fred")
)
canary = Mock()
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@validates("name", include_removes=True)
def validate_name(self, key, item, remove):
canary(key, item, remove)
self.classes.Address,
)
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@validates("addresses", include_removes=True)
def validate_address(self, key, item, remove):
if not remove:
self.classes.Address,
)
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
@validates("addresses", include_removes=True)
def validate_address(self, key, item, remove):
if not remove:
ne_(name, "fred")
return name + " modified"
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
sv = validates("name")(SomeValidator())
self.mapper_registry.map_imperatively(User, users)
bool(include_removes) and not include_removes.default
)
- class User(fixtures.ComparableEntity):
+ class User(ComparableEntity):
if need_remove_param:
@validates("addresses", **validate_kw)
canary(key, item)
return item
- class Address(fixtures.ComparableEntity):
+ class Address(ComparableEntity):
if need_remove_param:
@validates("user", **validate_kw)
--- /dev/null
+from sqlalchemy import create_engine
+from sqlalchemy import Pool
+from sqlalchemy import text
+
+
+def regular() -> None:
+ e = create_engine("sqlite://")
+
+ # EXPECTED_TYPE: Engine
+ reveal_type(e)
+
+ with e.connect() as conn:
+ # EXPECTED_TYPE: Connection
+ reveal_type(conn)
+
+ result = conn.execute(text("select * from table"))
+
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
+
+ with e.begin() as conn:
+ # EXPECTED_TYPE: Connection
+ reveal_type(conn)
+
+ result = conn.execute(text("select * from table"))
+
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+ status: str = engine.pool.status()
+ other_pool: Pool = engine.pool.recreate()
+
+ print(status, other_pool)
await session.commit()
+ trans_ctx = engine.begin()
+ async with trans_ctx as connection:
+ await connection.execute(select(A))
+
asyncio.run(async_main())
--- /dev/null
+from asyncio import current_task
+
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+
+
+engine = create_async_engine("")
+SM = async_sessionmaker(engine, class_=AsyncSession)
+
+async_session = AsyncSession(engine)
+
+as_session = async_scoped_session(SM, current_task)
+
+
+async def go() -> None:
+ r = await async_session.scalars(text("select 1"), params=[])
+ r.first()
+ sr = await async_session.stream_scalars(text("select 1"), params=[])
+ await sr.all()
+ r = await as_session.scalars(text("select 1"), params=[])
+ r.first()
+ sr = await as_session.stream_scalars(text("select 1"), params=[])
+ await sr.all()
+
+ async with engine.connect() as conn:
+ cr = await conn.scalars(text("select 1"))
+ cr.first()
+ scr = await conn.stream_scalars(text("select 1"))
+ await scr.all()
+
+ ast = async_session.get_transaction()
+ if ast:
+ ast.is_active
+ nt = async_session.get_nested_transaction()
+ if nt:
+ nt.is_active
--- /dev/null
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import AsyncConnection
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio.session import async_sessionmaker
+
+# async engine
+async_engine: AsyncEngine = create_async_engine("")
+async_engine.clear_compiled_cache()
+async_engine.update_execution_options()
+async_engine.get_execution_options()
+async_engine.url
+async_engine.pool
+async_engine.dialect
+async_engine.engine
+async_engine.name
+async_engine.driver
+async_engine.echo
+
+
+# async connection
+async def go_async_conn() -> None:
+ async_conn: AsyncConnection = await async_engine.connect()
+ async_conn.closed
+ async_conn.invalidated
+ async_conn.dialect
+ async_conn.default_isolation_level
+
+
+# async session
+AsyncSession.object_session(object())
+AsyncSession.identity_key()
+async_session: AsyncSession = AsyncSession(async_engine)
+in_: bool = "foo" in async_session
+list(async_session)
+async_session.add(object())
+async_session.add_all([])
+async_session.expire(object())
+async_session.expire_all()
+async_session.expunge(object())
+async_session.expunge_all()
+async_session.get_bind()
+async_session.is_modified(object())
+async_session.in_transaction()
+async_session.in_nested_transaction()
+async_session.dirty
+async_session.deleted
+async_session.new
+async_session.identity_map
+async_session.is_active
+async_session.autoflush
+async_session.no_autoflush
+async_session.info
+
+
+# async scoped session
+async def test_async_scoped_session() -> None:
+ async_scoped_session.object_session(object())
+ async_scoped_session.identity_key()
+ await async_scoped_session.close_all()
+ asm = async_sessionmaker()
+ async_ss = async_scoped_session(asm, lambda: 42)
+ value: bool = "foo" in async_ss
+ print(value)
+ list(async_ss)
+ async_ss.add(object())
+ async_ss.add_all([])
+ async_ss.begin()
+ async_ss.begin_nested()
+ await async_ss.close()
+ await async_ss.commit()
+ await async_ss.connection()
+ await async_ss.delete(object())
+ await async_ss.execute(text("select 1"))
+ async_ss.expire(object())
+ async_ss.expire_all()
+ async_ss.expunge(object())
+ async_ss.expunge_all()
+ await async_ss.flush()
+ await async_ss.get(object, 1)
+ async_ss.get_bind()
+ async_ss.is_modified(object())
+ await async_ss.merge(object())
+ await async_ss.refresh(object())
+ await async_ss.rollback()
+ await async_ss.scalar(text("select 1"))
+ async_ss.bind
+ async_ss.dirty
+ async_ss.deleted
+ async_ss.new
+ async_ss.identity_map
+ async_ss.is_active
+ async_ss.autoflush
+ async_ss.no_autoflush
+ async_ss.info
-from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine
-def regular() -> None:
- e = create_engine("sqlite://")
-
- # EXPECTED_TYPE: Engine
- reveal_type(e)
-
- with e.connect() as conn:
- # EXPECTED_TYPE: Connection
- reveal_type(conn)
-
- result = conn.execute(text("select * from table"))
-
- # EXPECTED_TYPE: CursorResult[Any]
- reveal_type(result)
-
- with e.begin() as conn:
- # EXPECTED_TYPE: Connection
- reveal_type(conn)
-
- result = conn.execute(text("select * from table"))
-
- # EXPECTED_TYPE: CursorResult[Any]
- reveal_type(result)
-
-
async def asyncio() -> None:
e = create_async_engine("sqlite://")
however this is not really working
"""
+from typing import Any
+from typing import Optional
+
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapper
-Base = declarative_base()
+
+class Base(DeclarativeBase):
+ pass
class A(Base):
# TODO: I can't get these to work, pylance and mypy both don't want
# to accommodate for different types for the first argument
-t: bool = inspect(a1).transient
+t: Optional[Any] = inspect(a1)
-m: Mapper = inspect(A)
+m: Mapper[Any] = inspect(A)
inspect(e).get_table_names()
-# NOPLUGINS
# this should pass typing with no plugins
from typing import Any
}
__table_args__ = (
- Index("my_index", name, type),
+ Index("my_index", name, type.desc()),
UniqueConstraint(name),
PrimaryKeyConstraint(id),
{"prefix": []},
--- /dev/null
+from datetime import datetime
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import registry
+from sqlalchemy.orm import Session
+from sqlalchemy.sql.functions import now
+from sqlalchemy.testing.schema import mapped_column
+
+mapper_registry: registry = registry()
+e = create_engine("sqlite:///database.db", echo=True)
+
+
+@mapper_registry.mapped
+class A:
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ date_time: Mapped[datetime]
+
+
+mapper_registry.metadata.create_all(e)
+
+with Session(e) as s:
+ a = A()
+ a.date_time = now()
+ s.add(a)
+ s.commit()
# EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.str\*?\]
reveal_type(Address.email_name)
- # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[experimental_relationship.Address\]\]
+ # EXPECTED_RE_TYPE: sqlalchemy.*.InstrumentedAttribute\[builtins.list\*?\[relationship.Address\]\]
reveal_type(User.addresses_style_one)
- # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[experimental_relationship.Address\]\]
+ # EXPECTED_RE_TYPE: sqlalchemy.orm.attributes.InstrumentedAttribute\[builtins.set\*?\[relationship.Address\]\]
reveal_type(User.addresses_style_two)
--- /dev/null
+from sqlalchemy import inspect
+from sqlalchemy import text
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import scoped_session
+from sqlalchemy.orm import sessionmaker
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class X(Base):
+ __tablename__ = "x"
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+
+scoped_session.object_session(object())
+scoped_session.identity_key()
+scoped_session.close_all()
+ss = scoped_session(sessionmaker())
+value: bool = "foo" in ss
+list(ss)
+ss.add(object())
+ss.add_all([])
+ss.begin()
+ss.begin_nested()
+ss.close()
+ss.commit()
+ss.connection()
+ss.delete(object())
+ss.execute(text("select 1"))
+ss.expire(object())
+ss.expire_all()
+ss.expunge(object())
+ss.expunge_all()
+ss.flush()
+ss.get(object, 1)
+b = ss.get_bind()
+ss.is_modified(object())
+ss.bulk_save_objects([])
+ss.bulk_insert_mappings(inspect(X), [])
+ss.bulk_update_mappings(inspect(X), [])
+ss.merge(object())
+q = (ss.query(object),)
+ss.refresh(object())
+ss.rollback()
+ss.scalar(text("select 1"))
+ss.bind
+ss.dirty
+ss.deleted
+ss.new
+ss.identity_map
+ss.is_active
+ss.autoflush
+ss.no_autoflush
+ss.info
--- /dev/null
+from sqlalchemy import Boolean
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import FetchedValue
+from sqlalchemy import ForeignKey
+from sqlalchemy import func
+from sqlalchemy import Index
+from sqlalchemy import Integer
+from sqlalchemy import literal_column
+from sqlalchemy import MetaData
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import text
+from sqlalchemy import true
+from sqlalchemy import UUID
+
+
+m = MetaData()
+
+
+t1 = Table(
+ "t1",
+ m,
+ Column("id", Integer, primary_key=True),
+ Column("data", String),
+ Column("data2", String(50)),
+ Column("timestamp", DateTime()),
+ Index(None, "data2"),
+)
+
+t2 = Table(
+ "t2",
+ m,
+ Column("t1id", ForeignKey("t1.id")),
+ Column("q", Integer, CheckConstraint("q > 5")),
+)
+
+t3 = Table(
+ "t3",
+ m,
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("t1id", ForeignKey(t1.c.id)),
+ PrimaryKeyConstraint("x", "y"),
+)
+
+t4 = Table(
+ "test_table",
+ m,
+ Column("i", UUID(as_uuid=True), nullable=False, primary_key=True),
+ Column("x", UUID(as_uuid=True), index=True),
+ Column("y", UUID(as_uuid=False), index=True),
+ Index("ix_xy_unique", "x", "y", unique=True),
+)
+
+
+# cols w/ no name or type, used by declarative
+c1: Column[int] = Column(ForeignKey(t3.c.x))
+# more colum args
+Column("name", Integer, index=True)
+Column(None, name="name")
+Column(Integer, name="name", index=True)
+Column("name", ForeignKey("a.id"))
+Column(ForeignKey("a.id"), type_=None, index=True)
+Column(ForeignKey("a.id"), name="name", type_=Integer())
+Column("name", None)
+Column("name", index=True)
+Column(ForeignKey("a.id"), name="name", index=True)
+Column(type_=None, index=True)
+Column(None, ForeignKey("a.id"))
+Column("name")
+Column(name="name", type_=None, index=True)
+Column(ForeignKey("a.id"), name="name", type_=None)
+Column(Integer)
+Column(ForeignKey("a.id"), type_=Integer())
+Column("name", Integer, ForeignKey("a.id"), index=True)
+Column("name", None, ForeignKey("a.id"), index=True)
+Column(ForeignKey("a.id"), index=True)
+Column("name", Integer)
+Column(Integer, name="name")
+Column(Integer, ForeignKey("a.id"), name="name", index=True)
+Column(ForeignKey("a.id"), type_=None)
+Column(ForeignKey("a.id"), name="name")
+Column(name="name", index=True)
+Column(type_=None)
+Column(None, index=True)
+Column(name="name", type_=None)
+Column(type_=Integer(), index=True)
+Column("name", Integer, ForeignKey("a.id"))
+Column(name="name", type_=Integer(), index=True)
+Column(Integer, ForeignKey("a.id"), index=True)
+Column("name", None, ForeignKey("a.id"))
+Column(index=True)
+Column("name", type_=None, index=True)
+Column("name", ForeignKey("a.id"), type_=Integer(), index=True)
+Column(ForeignKey("a.id"))
+Column(Integer, ForeignKey("a.id"))
+Column(Integer, ForeignKey("a.id"), name="name")
+Column("name", ForeignKey("a.id"), index=True)
+Column("name", type_=Integer(), index=True)
+Column(ForeignKey("a.id"), name="name", type_=Integer(), index=True)
+Column(name="name")
+Column("name", None, index=True)
+Column("name", ForeignKey("a.id"), type_=None, index=True)
+Column("name", type_=Integer())
+Column(None)
+Column(None, ForeignKey("a.id"), index=True)
+Column("name", ForeignKey("a.id"), type_=None)
+Column(type_=Integer())
+Column(None, ForeignKey("a.id"), name="name", index=True)
+Column(Integer, index=True)
+Column(ForeignKey("a.id"), name="name", type_=None, index=True)
+Column(ForeignKey("a.id"), type_=Integer(), index=True)
+Column(name="name", type_=Integer())
+Column(None, name="name", index=True)
+Column()
+Column(None, ForeignKey("a.id"), name="name")
+Column("name", type_=None)
+Column("name", ForeignKey("a.id"), type_=Integer())
+
+# server_default
+Column(Boolean, nullable=False, server_default=true())
+Column(DateTime, server_default=func.now(), nullable=False)
+Column(Boolean, server_default=func.xyzq(), nullable=False)
+# what would be *nice* to emit an error would be this, but this
+# is really not important, people don't usually put types in functions
+# as they are usually part of a bigger context where the type is known
+Column(Boolean, server_default=func.xyzq(type_=DateTime), nullable=False)
+Column(DateTime, server_default="now()")
+Column(DateTime, server_default=text("now()"))
+Column(DateTime, server_default=FetchedValue())
+Column(Boolean, server_default=literal_column("false", Boolean))
+Column("name", server_default=FetchedValue(), nullable=False)
+Column(server_default="now()", nullable=False)
+Column("name", Integer, server_default=text("now()"), nullable=False)
+Column(Integer, server_default=literal_column("42", Integer), nullable=False)
+
+# server_onupdate
+Column("name", server_onupdate=FetchedValue(), nullable=False)
+Column(server_onupdate=FetchedValue(), nullable=False)
+Column("name", Integer, server_onupdate=FetchedValue(), nullable=False)
+Column(Integer, server_onupdate=FetchedValue(), nullable=False)
+
+# TypeEngine.with_variant should accept both a TypeEngine instance and the Concrete Type
+Integer().with_variant(Integer, "mysql")
+Integer().with_variant(Integer(), "mysql")
+# Also test Variant.with_variant
+Integer().with_variant(Integer, "mysql").with_variant(Integer, "mysql")
+Integer().with_variant(Integer, "mysql").with_variant(Integer(), "mysql")
--- /dev/null
+from sqlalchemy import func
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class Foo(Base):
+ __tablename__ = "foo"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ a: Mapped[int]
+ b: Mapped[int]
+
+
+func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())
+func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
+func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
+func.row_number().over(order_by="a", partition_by=("a", "b"))
+func.row_number().over(partition_by="a", order_by=("a", "b"))
--- /dev/null
+import sqlalchemy as sa
+
+Book = sa.table(
+ "book",
+ sa.column("id", sa.Integer),
+ sa.column("name", sa.String),
+)
+Book.append_column(sa.column("other"))
+Book.corresponding_column(Book.c.id)
+
+value_expr = sa.values(
+ sa.column("id", sa.Integer), sa.column("name", sa.String), name="my_values"
+).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+sa.select(Book)
+sa.select(sa.literal_column("42"), sa.column("foo")).select_from(sa.table("t"))
--- /dev/null
+from decimal import Decimal
+from typing import Any
+
+from sqlalchemy import ARRAY
+from sqlalchemy import BigInteger
+from sqlalchemy import column
+from sqlalchemy import ColumnElement
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class A(Base):
+ __tablename__ = "a"
+ id: Mapped[int]
+ string: Mapped[str]
+ arr: Mapped[list[int]] = mapped_column(ARRAY(Integer))
+
+
+lt1: "ColumnElement[bool]" = A.id > A.id
+lt2: "ColumnElement[bool]" = A.id > 1
+lt3: "ColumnElement[bool]" = 1 < A.id
+
+le1: "ColumnElement[bool]" = A.id >= A.id
+le2: "ColumnElement[bool]" = A.id >= 1
+le3: "ColumnElement[bool]" = 1 <= A.id
+
+eq1: "ColumnElement[bool]" = A.id == A.id
+eq2: "ColumnElement[bool]" = A.id == 1
+# eq3: "ColumnElement[bool]" = 1 == A.id
+
+ne1: "ColumnElement[bool]" = A.id != A.id
+ne2: "ColumnElement[bool]" = A.id != 1
+# ne3: "ColumnElement[bool]" = 1 != A.id
+
+gt1: "ColumnElement[bool]" = A.id < A.id
+gt2: "ColumnElement[bool]" = A.id < 1
+gt3: "ColumnElement[bool]" = 1 > A.id
+
+ge1: "ColumnElement[bool]" = A.id <= A.id
+ge2: "ColumnElement[bool]" = A.id <= 1
+ge3: "ColumnElement[bool]" = 1 >= A.id
+
+
+# TODO "in" doesn't seem to pick up the typing of __contains__?
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "bool", variable has type "ColumnElement[bool]") # noqa: E501
+contains1: "ColumnElement[bool]" = A.id in A.arr
+# EXPECTED_MYPY: Incompatible types in assignment (expression has type "bool", variable has type "ColumnElement[bool]") # noqa: E501
+contains2: "ColumnElement[bool]" = A.id in A.string
+
+lshift1: "ColumnElement[int]" = A.id << A.id
+lshift2: "ColumnElement[int]" = A.id << 1
+lshift3: "ColumnElement[Any]" = A.string << 1
+
+rshift1: "ColumnElement[int]" = A.id >> A.id
+rshift2: "ColumnElement[int]" = A.id >> 1
+rshift3: "ColumnElement[Any]" = A.string >> 1
+
+concat1: "ColumnElement[str]" = A.string.concat(A.string)
+concat2: "ColumnElement[str]" = A.string.concat(1)
+concat3: "ColumnElement[str]" = A.string.concat("a")
+
+like1: "ColumnElement[bool]" = A.string.like("test")
+like2: "ColumnElement[bool]" = A.string.like("test", escape="/")
+ilike1: "ColumnElement[bool]" = A.string.ilike("test")
+ilike2: "ColumnElement[bool]" = A.string.ilike("test", escape="/")
+
+in_: "ColumnElement[bool]" = A.id.in_([1, 2])
+not_in: "ColumnElement[bool]" = A.id.not_in([1, 2])
+
+not_like1: "ColumnElement[bool]" = A.string.not_like("test")
+not_like2: "ColumnElement[bool]" = A.string.not_like("test", escape="/")
+not_ilike1: "ColumnElement[bool]" = A.string.not_ilike("test")
+not_ilike2: "ColumnElement[bool]" = A.string.not_ilike("test", escape="/")
+
+is_: "ColumnElement[bool]" = A.string.is_("test")
+is_not: "ColumnElement[bool]" = A.string.is_not("test")
+
+startswith: "ColumnElement[bool]" = A.string.startswith("test")
+endswith: "ColumnElement[bool]" = A.string.endswith("test")
+contains: "ColumnElement[bool]" = A.string.contains("test")
+match: "ColumnElement[bool]" = A.string.match("test")
+regexp_match: "ColumnElement[bool]" = A.string.regexp_match("test")
+
+regexp_replace: "ColumnElement[str]" = A.string.regexp_replace(
+ "pattern", "replacement"
+)
+between: "ColumnElement[bool]" = A.string.between("a", "b")
+
+adds: "ColumnElement[str]" = A.string + A.string
+add1: "ColumnElement[int]" = A.id + A.id
+add2: "ColumnElement[int]" = A.id + 1
+add3: "ColumnElement[int]" = 1 + A.id
+
+sub1: "ColumnElement[int]" = A.id - A.id
+sub2: "ColumnElement[int]" = A.id - 1
+sub3: "ColumnElement[int]" = 1 - A.id
+
+mul1: "ColumnElement[int]" = A.id * A.id
+mul2: "ColumnElement[int]" = A.id * 1
+mul3: "ColumnElement[int]" = 1 * A.id
+
+div1: "ColumnElement[float|Decimal]" = A.id / A.id
+div2: "ColumnElement[float|Decimal]" = A.id / 1
+div3: "ColumnElement[float|Decimal]" = 1 / A.id
+
+mod1: "ColumnElement[int]" = A.id % A.id
+mod2: "ColumnElement[int]" = A.id % 1
+mod3: "ColumnElement[int]" = 1 % A.id
+
+# unary
+
+neg: "ColumnElement[int]" = -A.id
+
+desc: "ColumnElement[int]" = A.id.desc()
+asc: "ColumnElement[int]" = A.id.asc()
+any_: "ColumnElement[bool]" = A.id.any_()
+all_: "ColumnElement[bool]" = A.id.all_()
+nulls_first: "ColumnElement[int]" = A.id.nulls_first()
+nulls_last: "ColumnElement[int]" = A.id.nulls_last()
+collate: "ColumnElement[str]" = A.string.collate("somelang")
+distinct: "ColumnElement[int]" = A.id.distinct()
+
+
+# custom ops
+col = column("flags", Integer)
+op_a: "ColumnElement[Any]" = col.op("&")(1)
+op_b: "ColumnElement[int]" = col.op("&", return_type=Integer)(1)
+op_c: "ColumnElement[str]" = col.op("&", return_type=String)("1")
+op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1")
+op_e: "ColumnElement[bool]" = col.bool_op("&")("1")
--- /dev/null
+import os
+
+from sqlalchemy import testing
+from sqlalchemy.testing import fixtures
+
+
+class MypyPlainTest(fixtures.MypyTest):
+ @testing.combinations(
+ *(
+ (os.path.basename(path), path)
+ for path in fixtures.MypyTest.file_combinations("plain_files")
+ ),
+ argnames="path",
+ id_="ia",
+ )
+ def test_mypy_no_plugin(self, mypy_typecheck_file, path):
+ mypy_typecheck_file(path)
functions_py = "lib/sqlalchemy/sql/functions.py"
-test_functions_py = "test/ext/mypy/plain_files/functions.py"
+test_functions_py = "test/typing/plain_files/sql/functions.py"
if __name__ == "__main__":