From eb9d00c4b4f4f15e871aa9ea88d41023054c6e97 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 9 Jan 2023 08:53:57 -0500 Subject: [PATCH] accept TableClause through mapped selectable chain type annotation somehow decided that TableClause doesn't have primary key fields which is not the case at all. In particular the "views" recipe relies on TableClause so adding a restriction like this does not make any sense. It seems the issue was to open this up for typing, by allowing TableClause out as far as ddl.sort_tables() typing is passing for now. Support it out in get_bind() etc. Fixes: #9071 Change-Id: If0e22e0e7df7bee0ff4b295b0ffacfbc6b7a0142 --- doc/build/changelog/unreleased_20/9071.rst | 7 ++++ lib/sqlalchemy/orm/mapper.py | 26 +++++------- lib/sqlalchemy/orm/session.py | 7 ++-- lib/sqlalchemy/sql/ddl.py | 7 +++- test/orm/test_bind.py | 34 +++++++++++++++- test/orm/test_mapper.py | 46 ++++++++++++++-------- 6 files changed, 89 insertions(+), 38 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9071.rst diff --git a/doc/build/changelog/unreleased_20/9071.rst b/doc/build/changelog/unreleased_20/9071.rst new file mode 100644 index 0000000000..d4645d71a7 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9071.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 9071 + + Fixed issue where an overly restrictive ORM mapping rule were added in 2.0 + which prevented mappings against :class:`.TableClause` objects, such as + those used in the view recipe on the wiki. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5220282748..20ad635b0f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -78,6 +78,7 @@ from ..sql import coercions from ..sql import expression from ..sql import operators from ..sql import roles +from ..sql import TableClause from ..sql import util as sql_util from ..sql import visitors from ..sql.cache_key import MemoizedHasCacheKey @@ -892,7 +893,7 @@ class Mapper( _dependency_processors: List[DependencyProcessor] _memoized_values: Dict[Any, Callable[[], Any]] _inheriting_mappers: util.WeakSequence[Mapper[Any]] - _all_tables: Set[Table] + _all_tables: Set[TableClause] _polymorphic_attr_key: Optional[str] _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] @@ -908,9 +909,10 @@ class Mapper( Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None] ] - tables: Sequence[Table] - """A sequence containing the collection of :class:`_schema.Table` objects - which this :class:`_orm.Mapper` is aware of. + tables: Sequence[TableClause] + """A sequence containing the collection of :class:`_schema.Table` + or :class:`_schema.TableClause` objects which this :class:`_orm.Mapper` + is aware of. If the mapper is mapped to a :class:`_expression.Join`, or an :class:`_expression.Alias` @@ -1534,17 +1536,9 @@ class Mapper( self.__dict__.pop("_configure_failed", None) def _configure_pks(self) -> None: - self.tables = cast( - "List[Table]", sql_util.find_tables(self.persist_selectable) - ) - for t in self.tables: - if not isinstance(t, Table): - raise sa_exc.ArgumentError( - f"ORM mappings can only be made against schema-level " - f"Table objects, not TableClause; got " - f"tableclause {t.name !r}" - ) - self._all_tables.update(t for t in self.tables if isinstance(t, Table)) + self.tables = sql_util.find_tables(self.persist_selectable) + + self._all_tables.update(t for t in self.tables) self._pks_by_table = {} self._cols_by_table = {} @@ -3802,7 +3796,7 @@ class Mapper( @HasMemoized.memoized_attribute def _sorted_tables(self): - table_to_mapper: Dict[Table, Mapper[Any]] = {} + table_to_mapper: Dict[TableClause, Mapper[Any]] = {} for mapper in self.base_mapper.self_and_descendants: for t in mapper.tables: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 0298b17a75..5bcb22a083 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -80,6 +80,7 @@ from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import Select +from ..sql import TableClause from ..sql import visitors from ..sql.base import CompileState from ..sql.schema import Table @@ -152,7 +153,7 @@ _PKIdentityArgument = Union[Any, Tuple[Any, ...]] _BindArguments = Dict[str, Any] _EntityBindKey = Union[Type[_O], "Mapper[_O]"] -_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table", str] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "TableClause", str] _SessionBind = Union["Engine", "Connection"] JoinTransactionMode = Literal[ @@ -2439,7 +2440,7 @@ class Session(_SessionClassMethods, EventTarget): if TYPE_CHECKING: assert isinstance(insp, Inspectable) - if isinstance(insp, Table): + if isinstance(insp, TableClause): self.__binds[insp] = bind elif insp_is_mapper(insp): self.__binds[insp.class_] = bind @@ -2480,7 +2481,7 @@ class Session(_SessionClassMethods, EventTarget): """ self._add_bind(mapper, bind) - def bind_table(self, table: Table, bind: _SessionBind) -> None: + def bind_table(self, table: TableClause, bind: _SessionBind) -> None: """Associate a :class:`_schema.Table` with a "bind", e.g. an :class:`_engine.Engine` or :class:`_engine.Connection`. diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 9a5b002446..5ea500a32b 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -42,6 +42,7 @@ if typing.TYPE_CHECKING: from .schema import SchemaItem from .schema import Sequence from .schema import Table + from .selectable import TableClause from ..engine.base import Connection from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType @@ -1179,9 +1180,11 @@ class SchemaDropper(InvokeDropDDLBase): def sort_tables( - tables: Iterable[Table], + tables: Iterable[TableClause], skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None, - extra_dependencies: Optional[typing_Sequence[Tuple[Table, Table]]] = None, + extra_dependencies: Optional[ + typing_Sequence[Tuple[TableClause, TableClause]] + ] = None, ) -> List[Table]: """Sort a collection of :class:`_schema.Table` objects based on dependency. diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 409c6244f0..13958ec91c 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -8,6 +8,7 @@ from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import select +from sqlalchemy import String from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import true @@ -186,6 +187,10 @@ class BindIntegrationTest(_fixtures.FixtureTest): (lambda Address: {"mapper": Address}, "e2"), (lambda Address: {"clause": Query([Address])._statement_20()}, "e2"), (lambda addresses: {"clause": select(addresses)}, "e2"), + (lambda Dingaling: {"mapper": Dingaling}, "e4"), + (lambda addresses_view: {"clause": addresses_view}, "e4"), + (lambda addresses_view: {"clause": select(addresses_view)}, "e4"), + (lambda users_view: {"clause": select(users_view)}, "e2"), ( lambda User, addresses: { "mapper": User, @@ -260,23 +265,48 @@ class BindIntegrationTest(_fixtures.FixtureTest): self.tables.addresses, self.classes.User, ) + Dingaling = self.classes.Dingaling self.mapper_registry.map_imperatively( User, users, properties={"addresses": relationship(Address)} ) self.mapper_registry.map_imperatively(Address, addresses) + users_view = table("users", Column("id", Integer, primary_key=True)) + addresses_view = table( + "addresses", + Column("id", Integer, primary_key=True), + Column("user_id", Integer), + Column("email_address", String), + ) + j = users_view.join( + addresses_view, users_view.c.id == addresses_view.c.user_id + ) + self.mapper_registry.map_imperatively( + Dingaling, + j, + properties={ + "user_t_id": users_view.c.id, + "address_id": addresses_view.c.id, + }, + ) + e1 = engines.testing_engine() e2 = engines.testing_engine() e3 = engines.testing_engine() + e4 = engines.testing_engine() testcase = testing.resolve_lambda( testcase, User=User, Address=Address, + Dingaling=Dingaling, e1=e1, e2=e2, e3=e3, + e4=e4, + users_view=users_view, + addresses_view=addresses_view, addresses=addresses, users=users, ) @@ -284,8 +314,10 @@ class BindIntegrationTest(_fixtures.FixtureTest): sess = Session(e3) sess.bind_mapper(User, e1) sess.bind_mapper(Address, e2) + sess.bind_mapper(Dingaling, e4) + sess.bind_table(users_view, e2) - engine = {"e1": e1, "e2": e2, "e3": e3}[expected] + engine = {"e1": e1, "e2": e2, "e3": e3, "e4": e4}[expected] conn = sess.connection(bind_arguments=testcase) is_(conn.engine, engine) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index b1d701d039..8b36f0b59e 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -363,33 +363,47 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): s, ) - def test_no_tableclause(self): - """It's not tested for a Mapper to have lower-case table() objects - as part of its collection of tables, and in particular these objects - won't report on constraints or primary keys, which while this doesn't - necessarily disqualify them from being part of a mapper, we don't - have assumptions figured out right now to accommodate them. + def test_tableclause_is_ok(self): + """2.0 during typing added a rule to disallow mappers to TableClause, + however we clearly allow mappings to any FromClause including + TableClause, this is how the map to views recipe works. - found_during_type_annotation + found_during_type_annotation -> then introduced a regression :( + + issue #9071 """ + User = self.classes.User + Address = self.classes.Address users = self.tables.users address = table( "address", - column("address_id", Integer), + Column("address_id", Integer, primary_key=True), column("user_id", Integer), ) + # manufacture the primary key collection which is otherwise + # not auto-populated from the above + address.primary_key.add(address.c.address_id) - with expect_raises_message( - sa.exc.ArgumentError, - "ORM mappings can only be made against schema-level Table " - "objects, not TableClause; got tableclause 'address'", - ): - self.mapper_registry.map_imperatively( - User, users.join(address, users.c.id == address.c.user_id) - ) + self.mapper_registry.map_imperatively( + User, users.join(address, users.c.id == address.c.user_id) + ) + + self.mapper_registry.map_imperatively(Address, address) + + self.assert_compile( + select(User), + "SELECT users.id, users.name, address.address_id, " + "address.user_id FROM users " + "JOIN address ON users.id = address.user_id", + ) + + self.assert_compile( + select(Address), + "SELECT address.address_id, address.user_id FROM address", + ) def test_reconfigure_on_other_mapper(self): """A configure trigger on an already-configured mapper -- 2.47.2