From: Mike Bayer Date: Fri, 21 Nov 2025 15:41:40 +0000 (-0500) Subject: filter_by works across multiple entities X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9a0d00433134b44a132104618b96516e47fff224;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git filter_by works across multiple entities The :meth:`_sql.Select.filter_by`, :meth:`_sql.Update.filter_by` and :meth:`_sql.Delete.filter_by` methods now search across all entities present in the statement, rather than limiting their search to only the last joined entity or the first FROM entity. This allows these methods to locate attributes unambiguously across multiple joined tables, resolving issues where changing the order of operations such as :meth:`_sql.Select.with_only_columns` would cause the method to fail. If an attribute name exists in more than one FROM clause entity, an :class:`_exc.AmbiguousColumnError` is now raised, indicating that :meth:`_sql.Select.filter` (or :meth:`_sql.Select.where`) should be used instead with explicit table-qualified column references. Fixes: #8601 Change-Id: I6a46b8f4784801f95f7980ca8ef92f1947653572 --- diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index bca5460cea..454a8562b2 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -721,6 +721,112 @@ up front, which would be verbose and not automatic. :ticket:`10635` +.. _change_8601: + +``filter_by()`` now searches across all FROM clause entities +------------------------------------------------------------- + +The :meth:`_sql.Select.filter_by` method, available for both Core +:class:`_sql.Select` objects and ORM-enabled select statements, has been +enhanced to search for attribute names across **all entities present in the +FROM clause** of the statement, rather than only looking at the last joined +entity or first FROM entity. + +This resolves a long-standing issue where the behavior of +:meth:`_sql.Select.filter_by` was sensitive to the order of operations. For +example, calling :meth:`_sql.Select.with_only_columns` after setting up joins +would reset which entity was searched, causing :meth:`_sql.Select.filter_by` +to fail even though the joined entity was still part of the FROM clause. + +Example - previously failing case now works:: + + from sqlalchemy import select, MetaData, Table, Column, Integer, String, ForeignKey + + metadata = MetaData() + + users = Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + + addresses = Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", ForeignKey("users.id")), + Column("email", String(100)), + ) + + # This now works in 2.1 - previously raised an error + stmt = ( + select(users) + .join(addresses) + .with_only_columns(users.c.id) # changes selected columns + .filter_by(email="foo@bar.com") # searches addresses table successfully + ) + +Ambiguous Attribute Names +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When an attribute name exists in more than one entity in the FROM clause, +:meth:`_sql.Select.filter_by` now raises :class:`_exc.AmbiguousColumnError`, +indicating that :meth:`_sql.Select.filter` should be used instead with +explicit column references:: + + # Both users and addresses have 'id' column + stmt = select(users).join(addresses) + + # Raises AmbiguousColumnError in 2.1 + stmt = stmt.filter_by(id=5) + + # Use filter() with explicit qualification instead + stmt = stmt.filter(addresses.c.id == 5) + +The same behavior applies to ORM entities:: + + from sqlalchemy.orm import Session + + stmt = select(User).join(Address) + + # If both User and Address have an 'id' attribute, this raises + # AmbiguousColumnError + stmt = stmt.filter_by(id=5) + + # Use filter() with explicit entity qualification + stmt = stmt.filter(Address.id == 5) + +Legacy Query Use is Unchanged +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The change to :meth:`.Select.filter_by` has **not** been applied to the +:meth:`.Query.filter_by` method of :class:`.Query`; as :class:`.Query` is +a legacy API, its behavior hasn't changed. + +Migration Path +^^^^^^^^^^^^^^ + +Code that was previously working should continue to work without modification +in the vast majority of cases. The only breaking changes would be: + +1. **Ambiguous names that were previously accepted**: If your code had joins + where :meth:`_sql.Select.filter_by` happened to use an ambiguous column + name but it worked because it searched only one entity, this will now + raise :class:`_exc.AmbiguousColumnError`. The fix is to use + :meth:`_sql.Select.filter` with explicit column qualification. + +2. **Different entity selection**: In rare cases where the old behavior of + selecting the "last joined" or "first FROM" entity was being relied upon, + :meth:`_sql.Select.filter_by` might now find the attribute in a different + entity. Review any :meth:`_sql.Select.filter_by` calls in complex + multi-entity queries. + +It's hoped that in most cases, this change will make +:meth:`_sql.Select.filter_by` more intuitive to use. + +:ticket:`8601` + .. _change_11234: diff --git a/doc/build/changelog/unreleased_21/8601.rst b/doc/build/changelog/unreleased_21/8601.rst new file mode 100644 index 0000000000..313339824b --- /dev/null +++ b/doc/build/changelog/unreleased_21/8601.rst @@ -0,0 +1,20 @@ +.. change:: + :tags: usecase, sql, orm + :tickets: 8601 + + The :meth:`_sql.Select.filter_by`, :meth:`_sql.Update.filter_by` and + :meth:`_sql.Delete.filter_by` methods now search across all entities + present in the statement, rather than limiting their search to only the + last joined entity or the first FROM entity. This allows these methods + to locate attributes unambiguously across multiple joined tables, + resolving issues where changing the order of operations such as + :meth:`_sql.Select.with_only_columns` would cause the method to fail. + + If an attribute name exists in more than one FROM clause entity, an + :class:`_exc.AmbiguousColumnError` is now raised, indicating that + :meth:`_sql.Select.filter` (or :meth:`_sql.Select.where`) should be used + instead with explicit table-qualified column references. + + .. seealso:: + + :ref:`change_8601` - Migration notes diff --git a/doc/build/tutorial/data_select.rst b/doc/build/tutorial/data_select.rst index 0b55d06c56..24e25e0b34 100644 --- a/doc/build/tutorial/data_select.rst +++ b/doc/build/tutorial/data_select.rst @@ -415,8 +415,8 @@ of ORM entities:: For simple "equality" comparisons against a single entity, there's also a popular method known as :meth:`_sql.Select.filter_by` which accepts keyword -arguments that match to column keys or ORM attribute names. It will filter -against the leftmost FROM clause or the last entity joined:: +arguments that match to column keys or ORM attribute names. It searches +across all entities in the FROM clause for the given attribute names:: >>> print(select(User).filter_by(name="spongebob", fullname="Spongebob Squarepants")) {printsql}SELECT user_account.id, user_account.name, user_account.fullname diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index e2bf6d5fe8..6d54def5b8 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -364,6 +364,18 @@ class NoSuchColumnError(InvalidRequestError, KeyError): """A nonexistent column is requested from a ``Row``.""" +class AmbiguousColumnError(InvalidRequestError): + """Raised when a column/attribute name is ambiguous across multiple + entities. + + This can occur when using :meth:`_sql.Select.filter_by` with multiple + joined tables that have columns with the same name. + + .. versionadded:: 2.1 + + """ + + class NoResultFound(InvalidRequestError): """A database result was required but none was found. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 848547a9dd..8abb20e127 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -73,7 +73,6 @@ from ..util.typing import TupleAny from ..util.typing import TypeVarTuple from ..util.typing import Unpack - if TYPE_CHECKING: from ._typing import _InternalEntityType from ._typing import OrmExecuteOptionsParameter @@ -1448,10 +1447,62 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState): return self @classmethod - def determine_last_joined_entity(cls, statement): - setup_joins = statement._setup_joins + def _get_filter_by_entities(cls, statement): + """Return all ORM entities for filter_by() searches. + + the ORM version for Select is special vs. update/delete since it needs + to navigate along select.join() paths which have ORM specific + directives. + + beyond that, it delivers other entities as the Mapper or Aliased + object rather than the Table or Alias, which mostly affects + how error messages regarding ambiguous entities or entity not + found are rendered; class-specific attributes like hybrid, + column_property() etc. work either way since + _entity_namespace_key_search_all() uses _entity_namespace(). + + DML Update and Delete objects, even though they also have filter_by() + and also accept ORM objects, don't use this routine since they + typically just have a single table, and if they have multiple tables + it's only via WHERE clause, which interestingly do not maintain ORM + annotations when used (that is, (User.name == + 'foo').left.table._annotations is empty; the ORMness of User.name is + lost in the expression construction process, since we don't annotate + (copy) Column objects with ORM entities the way we do for Table. + + .. versionadded:: 2.1 + """ + + def _setup_join_targets(collection): + for (target, *_) in collection: + if isinstance(target, attributes.QueryableAttribute): + yield target.entity + elif "_no_filter_by" not in target._annotations: + yield target + + entities = set(_setup_join_targets(statement._setup_joins)) + + for memoized in statement._memoized_select_entities: + entities.update(_setup_join_targets(memoized._setup_joins)) + + entities.update( + ( + from_obj._annotations["parententity"] + if "parententity" in from_obj._annotations + else from_obj + ) + for from_obj in statement._from_obj + if "_no_filter_by" not in from_obj._annotations + ) + + for element in statement._raw_columns: + if "entity_namespace" in element._annotations: + ens = element._annotations["entity_namespace"] + entities.add(ens) + elif "_no_filter_by" not in element._annotations: + entities.update(element._from_objects) - return _determine_last_joined_entity(setup_joins, None) + return entities @classmethod def all_selected_columns(cls, statement): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index c28c0a45d4..1000b1991f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1999,6 +1999,14 @@ class Query( entity of the query, or the last entity that was the target of a call to :meth:`_query.Query.join`. + .. note:: + + :class:`_query.Query` is a legacy construct as of SQLAlchemy 2.0. + See :meth:`_sql.Select.filter_by` for the comparable method on + 2.0-style :func:`_sql.select` constructs, where the behavior has + been enhanced in version 2.1 to search across all FROM clause + entities. See :ref:`change_8601` for background. + .. seealso:: :meth:`_query.Query.filter` - filter on SQL expressions. diff --git a/lib/sqlalchemy/orm/writeonly.py b/lib/sqlalchemy/orm/writeonly.py index b5aaf16e8c..d8183af86d 100644 --- a/lib/sqlalchemy/orm/writeonly.py +++ b/lib/sqlalchemy/orm/writeonly.py @@ -531,7 +531,13 @@ class _AbstractCollectionWriter(Generic[_T]): # note also, we are using the official ORM-annotated selectable # from __clause_element__(), see #7868 - self._from_obj = (prop.mapper.__clause_element__(), prop.secondary) + + # _no_filter_by annotation is to prevent this table from being + # considered by filter_by() as part of #8601 + self._from_obj = ( + prop.mapper.__clause_element__(), + prop.secondary._annotate({"_no_filter_by": True}), + ) else: self._from_obj = () diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 67eb44fc8d..86b2662d8d 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -20,6 +20,7 @@ import re from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict from typing import Final from typing import FrozenSet @@ -2551,3 +2552,50 @@ def _entity_namespace_key( raise exc.InvalidRequestError( 'Entity namespace for "%s" has no property "%s"' % (entity, key) ) from err + + +def _entity_namespace_key_search_all( + entities: Collection[Any], + key: str, +) -> SQLCoreOperations[Any]: + """Search multiple entities for a key, raise if ambiguous or not found. + + This is used by filter_by() to search across all FROM clause entities + when a single entity doesn't have the requested attribute. + + .. versionadded:: 2.1 + + Raises: + AmbiguousColumnError: If key exists in multiple entities + InvalidRequestError: If key doesn't exist in any entity + """ + + match_: SQLCoreOperations[Any] | None = None + + for entity in entities: + ns = _entity_namespace(entity) + # Check if the attribute exists + if hasattr(ns, key): + if match_ is not None: + entity_desc = ", ".join(str(e) for e in list(entities)[:3]) + if len(entities) > 3: + entity_desc += f", ... ({len(entities)} total)" + raise exc.AmbiguousColumnError( + f'Attribute name "{key}" is ambiguous; it exists in ' + f"multiple FROM clause entities ({entity_desc}). " + f"Use filter() with explicit column references instead " + f"of filter_by()." + ) + match_ = getattr(ns, key) + + if match_ is None: + # No entity has this attribute + entity_desc = ", ".join(str(e) for e in list(entities)[:3]) + if len(entities) > 3: + entity_desc += f", ... ({len(entities)} total)" + raise exc.InvalidRequestError( + f'None of the FROM clause entities have a property "{key}". ' + f"Searched entities: {entity_desc}" + ) + + return match_ diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index e85e98a6f8..590d54db2e 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -38,7 +38,7 @@ from . import util as sql_util from ._typing import _unexpected_kw from ._typing import is_column_element from ._typing import is_named_from_clause -from .base import _entity_namespace_key +from .base import _entity_namespace_key_search_all from .base import _exclusive_against from .base import _from_objects from .base import _generative @@ -1545,18 +1545,52 @@ class DMLWhereBase: return self.where(*criteria) - def _filter_by_zero(self) -> _DMLTableElement: - return self.table - def filter_by(self, **kwargs: Any) -> Self: - r"""apply the given filtering criterion as a WHERE clause - to this select. + r"""Apply the given filtering criterion as a WHERE clause + to this DML statement, using keyword expressions. - """ - from_entity = self._filter_by_zero() + E.g.:: + + stmt = update(User).filter_by(name="some name").values(fullname="New Name") + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + stmt = delete(User).filter_by(name="some name", id=5) + + The keyword expressions are extracted by searching across **all + entities present in the FROM clause** of the statement. + + .. versionchanged:: 2.1 + + :meth:`.DMLWhereBase.filter_by` now searches across all FROM clause + entities, consistent with :meth:`_sql.Select.filter_by`. + + .. seealso:: + + :meth:`.DMLWhereBase.where` - filter on SQL expressions. + + :meth:`_sql.Select.filter_by` + + """ # noqa: E501 + + entities: set[Any] + + if not isinstance(self.table, TableClause): + entities = set( + sql_util.find_tables( + self.table, check_columns=False, include_joins=False + ) + ) + else: + entities = {self.table} + + if self.whereclause is not None: + entities.update(self.whereclause._from_objects) clauses = [ - _entity_namespace_key(from_entity, key) == value + _entity_namespace_key_search_all(entities, key) == value for key, value in kwargs.items() ] return self.filter(*clauses) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8113944caa..0668bc3672 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -21,6 +21,7 @@ from typing import Any as TODO_Any from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict from typing import Generic from typing import Iterable @@ -61,7 +62,7 @@ from .annotation import SupportsCloneAnnotations from .base import _clone from .base import _cloned_difference from .base import _cloned_intersection -from .base import _entity_namespace_key +from .base import _entity_namespace_key_search_all from .base import _EntityNamespace from .base import _expand_cloned from .base import _from_objects @@ -110,7 +111,6 @@ from ..util.typing import Self from ..util.typing import TupleAny from ..util.typing import Unpack - and_ = BooleanClauseList.and_ @@ -5095,13 +5095,43 @@ class SelectState(util.MemoizedSlots, CompileState): return with_cols, only_froms, only_cols @classmethod - def determine_last_joined_entity( - cls, stmt: Select[Unpack[TupleAny]] - ) -> Optional[_JoinTargetElement]: - if stmt._setup_joins: - return stmt._setup_joins[-1][0] - else: - return None + def _get_filter_by_entities( + cls, statement: Select[Unpack[TupleAny]] + ) -> Collection[ + Union[FromClause, _JoinTargetProtocol, ColumnElement[Any]] + ]: + """Return all entities to search for filter_by() attributes. + + This includes: + + * All joined entities from _setup_joins + * Memoized entities from previous operations (e.g., + before with_only_columns) + * Explicit FROM objects from _from_obj + * Entities inferred from _raw_columns + + .. versionadded:: 2.1 + + """ + entities: set[ + Union[FromClause, _JoinTargetProtocol, ColumnElement[Any]] + ] + + entities = set( + join_element[0] for join_element in statement._setup_joins + ) + + for memoized in statement._memoized_select_entities: + entities.update( + join_element[0] for join_element in memoized._setup_joins + ) + + entities.update(statement._from_obj) + + for col in statement._raw_columns: + entities.update(col._from_objects) + + return entities @classmethod def all_selected_columns( @@ -5556,24 +5586,6 @@ class Select( return self.where(*criteria) - def _filter_by_zero( - self, - ) -> Union[ - FromClause, _JoinTargetProtocol, ColumnElement[Any], TextClause - ]: - if self._setup_joins: - meth = SelectState.get_plugin_class( - self - ).determine_last_joined_entity - _last_joined_entity = meth(self) - if _last_joined_entity is not None: - return _last_joined_entity - - if self._from_obj: - return self._from_obj[0] - - return self._raw_columns[0] - if TYPE_CHECKING: @overload @@ -5592,14 +5604,60 @@ class Select( def scalar_subquery(self) -> ScalarSelect[Any]: ... def filter_by(self, **kwargs: Any) -> Self: - r"""apply the given filtering criterion as a WHERE clause - to this select. + r"""Apply the given filtering criterion as a WHERE clause + to this select, using keyword expressions. + + E.g.:: + + stmt = select(User).filter_by(name="some name") + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + stmt = select(User).filter_by(name="some name", id=5) + + The keyword expressions are extracted by searching across **all + entities present in the FROM clause** of the statement. If a + keyword name is present in more than one entity, + :class:`_exc.AmbiguousColumnError` is raised. In this case, use + :meth:`_sql.Select.filter` or :meth:`_sql.Select.where` with + explicit column references:: + + # both User and Address have an 'id' attribute + stmt = select(User).join(Address).filter_by(id=5) + # raises AmbiguousColumnError + + # use filter() with explicit qualification instead + stmt = select(User).join(Address).filter(Address.id == 5) + + .. versionchanged:: 2.1 + + :meth:`_sql.Select.filter_by` now searches across all FROM clause + entities rather than only searching the last joined entity or first + FROM entity. This allows the method to locate attributes + unambiguously across multiple joined tables. The new + :class:`_exc.AmbiguousColumnError` is raised when an attribute name + is present in more than one entity. + + See :ref:`change_8601` for migration notes. + + .. seealso:: + + :ref:`tutorial_selecting_data` - in the :ref:`unified_tutorial` + + :meth:`_sql.Select.filter` - filter on SQL expressions. + + :meth:`_sql.Select.where` - filter on SQL expressions. """ - from_entity = self._filter_by_zero() + # Get all entities via plugin system + all_entities = SelectState.get_plugin_class( + self + )._get_filter_by_entities(self) clauses = [ - _entity_namespace_key(from_entity, key) == value + _entity_namespace_key_search_all(all_entities, key) == value for key, value in kwargs.items() ] return self.filter(*clauses) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 05f8d355d6..f0c8fd48dd 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -673,7 +673,14 @@ class AssertsCompiledSQL: cc = re.sub(r"[\n\t]", "", str(c)) - eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + if isinstance(result, re.Pattern): + assert result.match(cc), "%r !~ %r on dialect %r" % ( + cc, + result, + dialect, + ) + else: + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) if checkparams is not None: if render_postcompile: diff --git a/test/base/test_except.py b/test/base/test_except.py index 2a45cdcc95..44f7931a11 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -440,6 +440,7 @@ ALL_EXC = [ ( [ sa_exceptions.ArgumentError, + sa_exceptions.AmbiguousColumnError, sa_exceptions.DuplicateColumnError, sa_exceptions.ConstraintColumnNotFoundError, sa_exceptions.NoSuchModuleError, diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 97c81fd532..a184195417 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -5,6 +5,7 @@ from decimal import Decimal from typing import TYPE_CHECKING from sqlalchemy import column +from sqlalchemy import delete from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import from_dml_column @@ -710,6 +711,29 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): a1 = A(_value=10) eq_(a1.value, 5) + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_filter_by_update_dml(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) + self.assert_compile( + update(A).filter_by(value="foo").values(value="bar"), + "UPDATE a SET foo(value) + bar(value)=:param_1 " + "WHERE foo(a.value) + bar(a.value) = :param_2", + ) + + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_filter_by_delete_dml(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) + self.assert_compile( + delete(A).filter_by(value="foo"), + "DELETE FROM a WHERE foo(a.value) + bar(a.value) = :param_1", + ) + class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/orm/dml/test_orm_upd_del_assorted.py b/test/orm/dml/test_orm_upd_del_assorted.py index 28d729a3c8..cd1bcf0046 100644 --- a/test/orm/dml/test_orm_upd_del_assorted.py +++ b/test/orm/dml/test_orm_upd_del_assorted.py @@ -1,10 +1,13 @@ from __future__ import annotations +import re import uuid from sqlalchemy import Computed from sqlalchemy import delete +from sqlalchemy import exc from sqlalchemy import FetchedValue +from sqlalchemy import ForeignKey from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import literal @@ -14,8 +17,11 @@ from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session @@ -431,3 +437,340 @@ class PGIssue11849Test(fixtures.DeclarativeMappedTest): # synchronizes on load eq_(obj.test_field, {"test1": 1, "test2": "2", "test3": {"test4": 4}}) + + +class _FilterByDMLSuite(fixtures.MappedTest, AssertsCompiledSQL): + """Base test suite for filter_by() on ORM DML statements. + + Tests filter_by() functionality for UPDATE and DELETE with ORM entities, + verifying it can locate attributes across multiple joined tables and + raises AmbiguousColumnError for ambiguous names. + """ + + __dialect__ = "default_enhanced" + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30), nullable=False), + Column("department_id", ForeignKey("departments.id")), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("name", String(30), nullable=False), + Column("email_address", String(50), nullable=False), + ) + Table( + "dingalings", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", None, ForeignKey("addresses.id")), + Column("data", String(30)), + ) + Table( + "departments", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) + + @classmethod + def setup_classes(cls): + class User(cls.Comparable): + pass + + class Address(cls.Comparable): + pass + + class Dingaling(cls.Comparable): + pass + + class Department(cls.Comparable): + pass + + @classmethod + def setup_mappers(cls): + User = cls.classes.User + users = cls.tables.users + + Address = cls.classes.Address + addresses = cls.tables.addresses + + Dingaling = cls.classes.Dingaling + dingalings = cls.tables.dingalings + + Department = cls.classes.Department + departments = cls.tables.departments + + cls.mapper_registry.map_imperatively( + User, + users, + properties={ + "addresses": relationship(Address), + "department": relationship(Department), + }, + ) + cls.mapper_registry.map_imperatively( + Address, + addresses, + properties={"dingalings": relationship(Dingaling)}, + ) + cls.mapper_registry.map_imperatively(Dingaling, dingalings) + cls.mapper_registry.map_imperatively(Department, departments) + + def test_filter_by_basic(self, one_table_statement): + """Test filter_by with a single ORM entity.""" + stmt = one_table_statement + + stmt = stmt.filter_by(name="somename") + self.assert_compile( + stmt, + re.compile(r"(?:UPDATE|DELETE) .* WHERE users\.name = :name_1"), + params={"name_1": "somename"}, + ) + + def test_filter_by_two_tables_ambiguous_id(self, two_table_statement): + """Test filter_by raises error when 'id' is ambiguous.""" + stmt = two_table_statement + + # Filter by 'id' which exists in both tables - should raise error + with expect_raises_message( + exc.AmbiguousColumnError, + 'Attribute name "id" is ambiguous', + ): + stmt.filter_by(id=5) + + def test_filter_by_two_tables_secondary(self, two_table_statement): + """Test filter_by finds attribute in secondary table.""" + stmt = two_table_statement + + # Filter by 'email_address' which only exists in addresses table + stmt = stmt.filter_by(email_address="test@example.com") + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* addresses\.email_address = " + r":email_address_1" + ), + ) + + def test_filter_by_three_tables_ambiguous(self, three_table_statement): + """Test filter_by raises AmbiguousColumnError for ambiguous + names.""" + stmt = three_table_statement + + # interestingly, UPDATE/DELETE dont use an ORM specific version + # for filter_by() entity lookup, unlike SELECT + with expect_raises_message( + exc.AmbiguousColumnError, + 'Attribute name "name" is ambiguous; it exists in multiple FROM ' + "clause entities " + r"\((?:users(?:, )?|dingalings(?:, )?|addresses(?:, )?){3}\).", + ): + stmt.filter_by(name="ambiguous") + + def test_filter_by_four_tables_ambiguous(self, four_table_statement): + """test the ellipses version of the ambiguous message""" + stmt = four_table_statement + + # interestingly, UPDATE/DELETE dont use an ORM specific version + # for filter_by() entity lookup, unlike SELECT + with expect_raises_message( + exc.AmbiguousColumnError, + r'Attribute name "name" is ambiguous; it exists in multiple ' + r"FROM clause entities " + r"\((?:dingalings, |departments, |users, |addresses, ){3}\.\.\. " + r"\(4 total\)\)", + ): + stmt.filter_by(name="ambiguous") + + def test_filter_by_three_tables_notfound(self, three_table_statement): + """test the three or fewer table not found message""" + stmt = three_table_statement + + with expect_raises_message( + exc.InvalidRequestError, + r'None of the FROM clause entities have a property "unknown". ' + r"Searched entities: (?:dingalings(?:, )?" + r"|users(?:, )?|addresses(?:, )?){3}", + ): + stmt.filter_by(unknown="notfound") + + def test_filter_by_four_tables_notfound(self, four_table_statement): + """test the ellipses version of the not found message""" + stmt = four_table_statement + + with expect_raises_message( + exc.InvalidRequestError, + r'None of the FROM clause entities have a property "unknown". ' + r"Searched entities: " + r"(?:dingalings, |departments, |users, |addresses, ){3}\.\.\. " + r"\(4 total\)", + ): + stmt.filter_by(unknown="notfound") + + def test_filter_by_three_tables_primary(self, three_table_statement): + """Test filter_by finds attribute in primary table with three + tables.""" + stmt = three_table_statement + + # Filter by 'id' - ambiguous across all three tables + with expect_raises_message( + exc.AmbiguousColumnError, + 'Attribute name "id" is ambiguous', + ): + stmt.filter_by(id=5) + + def test_filter_by_three_tables_secondary(self, three_table_statement): + """Test filter_by finds attribute in secondary table.""" + stmt = three_table_statement + + # Filter by 'email_address' which only exists in Address + stmt = stmt.filter_by(email_address="test@example.com") + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* addresses\.email_address = " + r":email_address_1" + ), + ) + + def test_filter_by_three_tables_tertiary(self, three_table_statement): + """Test filter_by finds attribute in third table (Dingaling).""" + stmt = three_table_statement + + # Filter by 'data' which only exists in dingalings + stmt = stmt.filter_by(data="somedata") + self.assert_compile( + stmt, + re.compile(r"(?:UPDATE|DELETE) .* dingalings\.data = :data_1"), + ) + + def test_filter_by_three_tables_user_id(self, three_table_statement): + """Test filter_by finds user_id in Address (unambiguous).""" + stmt = three_table_statement + + # Filter by 'user_id' which only exists in addresses + stmt = stmt.filter_by(user_id=7) + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* addresses\.user_id = :user_id_1" + ), + ) + + def test_filter_by_three_tables_address_id(self, three_table_statement): + """Test filter_by finds address_id in Dingaling (unambiguous).""" + stmt = three_table_statement + + # Filter by 'address_id' which only exists in dingalings + stmt = stmt.filter_by(address_id=3) + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* dingalings\.address_id = " + r":address_id_1" + ), + ) + + +class UpdateFilterByTest(_FilterByDMLSuite): + @testing.fixture + def one_table_statement(self): + User = self.classes.User + + return update(User).values(name="newname") + + @testing.fixture + def two_table_statement(self): + User = self.classes.User + Address = self.classes.Address + + return ( + update(User) + .values(name="newname") + .where(User.id == Address.user_id) + ) + + @testing.fixture + def three_table_statement(self): + User = self.classes.User + Address = self.classes.Address + Dingaling = self.classes.Dingaling + + return ( + update(User) + .values(name="newname") + .where(User.id == Address.user_id) + .where(Address.id == Dingaling.address_id) + ) + + @testing.fixture + def four_table_statement(self): + User = self.classes.User + Address = self.classes.Address + Dingaling = self.classes.Dingaling + Department = self.classes.Department + + return ( + update(User) + .values(name="newname") + .where(User.id == Address.user_id) + .where(Address.id == Dingaling.address_id) + .where(Department.id == User.department_id) + ) + + +class DeleteFilterByTest(_FilterByDMLSuite): + @testing.fixture + def one_table_statement(self): + User = self.classes.User + + return delete(User) + + @testing.fixture + def two_table_statement(self): + User = self.classes.User + Address = self.classes.Address + + return delete(User).where(User.id == Address.user_id) + + @testing.fixture + def three_table_statement(self): + User = self.classes.User + Address = self.classes.Address + Dingaling = self.classes.Dingaling + + return ( + delete(User) + .where(User.id == Address.user_id) + .where(Address.id == Dingaling.address_id) + ) + + @testing.fixture + def four_table_statement(self): + User = self.classes.User + Address = self.classes.Address + Dingaling = self.classes.Dingaling + Department = self.classes.Department + + return ( + delete(User) + .where(User.id == Address.user_id) + .where(Address.id == Dingaling.address_id) + .where(Department.id == User.department_id) + ) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 3f6ddb28fe..4244fdbe7b 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -14,6 +14,7 @@ from sqlalchemy import null from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true @@ -45,6 +46,7 @@ from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock @@ -65,17 +67,6 @@ from ..sql.test_compiler import CorrelateTest as _CoreCorrelateTest class SelectableTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default" - def test_filter_by(self): - User, Address = self.classes("User", "Address") - - stmt = select(User).filter_by(name="ed") - - self.assert_compile( - stmt, - "SELECT users.id, users.name FROM users " - "WHERE users.name = :name_1", - ) - def test_c_accessor_not_mutated_subq(self): """test #6394, ensure all_selected_columns is generated each time""" User = self.classes.User @@ -874,15 +865,22 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "User", "Address", "Order", "Item", "Keyword" ) + # Note: Both Order and Item have 'description' column + # After joining Item, filter_by(description=...) would be ambiguous + # Use explicit filter() for Item.description to avoid ambiguity stmt = ( select(User) .filter_by(name="n1") .join(User.addresses) .filter_by(email_address="a1") .join_from(User, Order, User.orders) - .filter_by(description="d1") + .filter_by( + description="d1" + ) # Order.description (no ambiguity yet) .join(Order.items) - .filter_by(description="d2") + .filter( + Item.description == "d2" + ) # Use explicit filter() to avoid ambiguity ) self.assert_compile( stmt, @@ -3222,3 +3220,207 @@ class CrudParamOverlapTest(test_compiler.CrudParamOverlapTest): type_.fail() yield table1 + + +class FilterByTest(QueryTest, AssertsCompiledSQL): + __dialect__ = "default" + + def test_filter_by(self): + User, Address = self.classes("User", "Address") + + stmt = select(User).filter_by(name="ed") + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users " + "WHERE users.name = :name_1", + ) + + def test_filter_by_w_join(self): + User, Address = self.classes("User", "Address") + + stmt = select(User).join(Address).filter_by(name="ed") + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.name = :name_1", + ) + + def test_filter_by_core_column_elem_only(self): + User, Address = self.classes("User", "Address") + + stmt = ( + select(Address.__table__.c.id) + .select_from(User) + .filter_by(email_address="ed") + ) + + self.assert_compile( + stmt, + "SELECT addresses.id FROM users, addresses " + "WHERE addresses.email_address = :email_address_1", + ) + + def test_filter_by_select_from(self): + User, Address = self.classes("User", "Address") + + stmt = select("*").select_from(User).filter_by(name="ed") + + self.assert_compile( + stmt, "SELECT * FROM users WHERE users.name = :name_1" + ) + + def test_filter_by_across_join_entities_issue_8601(self): + """Test issue #8601 - filter_by after with_only_columns.""" + User, Address = self.classes("User", "Address") + + # The original failing case from issue #8601 + stmt = ( + select(User) + .join(Address) + .with_only_columns(User.id) + .filter_by(email_address="foo@bar.com") + ) + + self.assert_compile( + stmt, + "SELECT users.id FROM users " + "JOIN addresses ON users.id = addresses.user_id " + "WHERE addresses.email_address = :email_address_1", + ) + + def test_filter_by_unambiguous_across_orm_joins(self): + """Test filter_by finds unambiguous attributes in ORM joins.""" + User, Address = self.classes("User", "Address") + + # email_address only exists in Address + stmt = ( + select(User) + .join(Address) + .filter_by(email_address="test@example.com") + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users " + "JOIN addresses ON users.id = addresses.user_id " + "WHERE addresses.email_address = :email_address_1", + ) + + def test_filter_by_searches_all_joined_entities(self): + """Test that filter_by searches all joined entities, not just last""" + User, Address, Order = self.classes("User", "Address", "Order") + + # Filter by Address attribute after joining to Order + stmt = ( + select(User) + .join(User.addresses) + .join(User.orders) + .filter_by(email_address="test@example.com") + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users " + "JOIN addresses ON users.id = addresses.user_id " + "JOIN orders ON users.id = orders.user_id " + "WHERE addresses.email_address = :email_address_1", + ) + + def test_filter_by_with_only_columns_preserves_joins(self): + """Verify with_only_columns doesn't affect filter_by entity search""" + User, Address = self.classes("User", "Address") + + # Change selected columns but still search all FROM entities + stmt = ( + select(User) + .join(User.addresses) + .with_only_columns(User.id, User.name) + .filter_by(email_address="foo") + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users " + "JOIN addresses ON users.id = addresses.user_id " + "WHERE addresses.email_address = :email_address_1", + ) + + def test_filter_by_column_not_in_any_orm_entity(self): + """Test error when attribute not found in any ORM entity""" + User, Address = self.classes("User", "Address") + + stmt = select(User).join(Address) + + with expect_raises_message( + exc.InvalidRequestError, + 'None of the FROM clause entities have a property "nonexistent"', + ): + stmt.filter_by(nonexistent="foo") + + @testing.fixture + def m2m_fixture(self, decl_base): + atob = Table( + "atob", + decl_base.metadata, + Column("a_id", ForeignKey("a.a_id")), + Column("b_id", ForeignKey("b.b_id")), + Column("association", String(50)), + ) + + class A(decl_base): + __tablename__ = "a" + + a_id: Mapped[int] = mapped_column(primary_key=True) + bs = relationship("B", secondary=atob) + + class B(decl_base): + __tablename__ = "b" + + b_id: Mapped[int] = mapped_column(primary_key=True) + + return A, B, atob + + def test_filter_by_ignores_secondary_w_overlap(self, m2m_fixture): + A, B, _ = m2m_fixture + stmt = select(A).join(A.bs).filter_by(a_id=5) + self.assert_compile( + stmt, + "SELECT a.a_id FROM a JOIN atob AS atob_1 ON a.a_id = atob_1.a_id " + "JOIN b ON b.b_id = atob_1.b_id WHERE a.a_id = :a_id_1", + ) + + def test_filter_by_ignores_secondary_will_raise(self, m2m_fixture): + A, B, _ = m2m_fixture + + with expect_raises_message( + exc.InvalidRequestError, + 'None of the FROM clause entities have a property "association". ' + r"Searched entities: Mapper\[(?:A|B).*], Mapper\[(?:A|B).*]", + ): + select(A).join(A.bs).filter_by(association="hi") + + @testing.variation("jointype", ["join", "froms"]) + def test_filter_by_with_table(self, m2m_fixture, jointype): + A, B, atob = m2m_fixture + + if jointype.join: + stmt = select(A).join(atob).filter_by(b_id=5) + self.assert_compile( + stmt, + "SELECT a.a_id FROM a JOIN atob ON a.a_id = atob.a_id " + "WHERE atob.b_id = :b_id_1", + ) + elif jointype.froms: + stmt = ( + select(A) + .select_from(A, atob) + .where(A.a_id == atob.c.a_id) + .filter_by(b_id=5) + ) + self.assert_compile( + stmt, + "SELECT a.a_id FROM a, atob WHERE a.a_id = atob.a_id " + "AND atob.b_id = :b_id_1", + ) diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index f35e5f0471..9661c22421 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -926,6 +926,9 @@ class WriteOnlyTest( "polymorphic_identity": "sub", } + # NOTE: keep filter_by(id=1) here because this also tests that an + # overlap issue does not occur with filter_by and the secondary table + # being explicitly added to _from_obj gp = GrandParent(id=1) make_transient_to_detached(gp) self.assert_compile( diff --git a/test/orm/test_events.py b/test/orm/test_events.py index e08987eb80..c1da9b06a4 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -455,7 +455,7 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): sess.execute( select(User.id, Address.email_address, User.name) .join(Address) - .filter_by(id=7) + .filter_by(name="somename") ) eq_( @@ -510,7 +510,7 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): canary = self._flag_fixture(sess) - sess.execute(select(User).join(Address).filter_by(id=7)) + sess.execute(select(User).join(Address).filter_by(name="somename")) eq_( canary.mock_calls, diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py index 2c3187fd2d..83f92dcdd0 100644 --- a/test/sql/test_delete.py +++ b/test/sql/test_delete.py @@ -17,6 +17,7 @@ from sqlalchemy.testing import expect_deprecated from sqlalchemy.testing import fixtures from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +from .test_update import _FilterByDMLSuite class _DeleteTestBase: @@ -381,3 +382,44 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): def _assert_table(self, connection, table, expected): stmt = table.select().order_by(table.c.id) eq_(connection.execute(stmt).fetchall(), expected) + + +class DeleteFilterByTest(_FilterByDMLSuite): + @testing.fixture + def one_table_statement(self): + users = self.tables.users + + return users.delete() + + @testing.fixture + def two_table_statement(self): + users = self.tables.users + addresses = self.tables.addresses + + return users.delete().where(users.c.id == addresses.c.user_id) + + @testing.fixture + def three_table_statement(self): + users = self.tables.users + addresses = self.tables.addresses + dingalings = self.tables.dingalings + + return ( + users.delete() + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.id == dingalings.c.address_id) + ) + + @testing.fixture + def four_table_statement(self): + users = self.tables.users + addresses = self.tables.addresses + dingalings = self.tables.dingalings + departments = self.tables.departments + + return ( + users.delete() + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.id == dingalings.c.address_id) + .where(departments.c.id == users.c.department_id) + ) diff --git a/test/sql/test_select.py b/test/sql/test_select.py index e655003781..1b810073e3 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -408,13 +408,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_joins_w_filter_by(self): + # Note: Both parent and child have a "data" column + # After the join, filter_by will see both entities + # To avoid ambiguity, filter first on parent before join, or use + # filter() with explicit column references stmt = ( select(parent) - .filter_by(data="p1") + .filter_by(data="p1") # Filter parent.data before the join .join(child) - .filter_by(data="c1") + .filter(child.c.data == "c1") # Explicit to avoid ambiguity .join_from(table1, table2, table1.c.myid == table2.c.otherid) - .filter_by(otherid=5) + .filter_by(otherid=5) # otherid is unambiguous ) self.assert_compile( @@ -482,7 +486,8 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): def test_filter_by_no_property_from_table(self): assert_raises_message( exc.InvalidRequestError, - 'Entity namespace for "mytable" has no property "foo"', + 'None of the FROM clause entities have a property "foo". ' + "Searched entities: mytable", select(table1).filter_by, foo="bar", ) @@ -490,11 +495,100 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): def test_filter_by_no_property_from_col(self): assert_raises_message( exc.InvalidRequestError, - 'Entity namespace for "mytable.myid" has no property "foo"', + 'None of the FROM clause entities have a property "foo". ' + "Searched entities: mytable", select(table1.c.myid).filter_by, foo="bar", ) + def test_filter_by_across_join_entities_issue_8601(self): + """Test issue #8601 - filter_by after with_only_columns.""" + # The original failing case from issue #8601 + # Use 'parent_id' which only exists in child table + stmt = ( + select(parent) + .join(child) + .with_only_columns(parent.c.id) + .filter_by(parent_id=5) + ) + self.assert_compile( + stmt, + "SELECT parent.id FROM parent " + "JOIN child ON parent.id = child.parent_id " + "WHERE child.parent_id = :parent_id_1", + checkparams={"parent_id_1": 5}, + ) + + def test_filter_by_ambiguous_column_error(self): + """Test filter_by() raises AmbiguousColumnError.""" + # Both parent and child have 'data' column + stmt = select(parent).join(child) + + with expect_raises_message( + exc.AmbiguousColumnError, + 'Attribute name "data" is ambiguous; it exists in multiple ' + r"FROM clause entities \((?:parent(?:, )?" + r"|child(?:, )?){2}\).", + ): + stmt.filter_by(data="foo") + + def test_filter_by_unambiguous_across_joins(self): + """Test filter_by finds unambiguous columns across multiple joins""" + # 'parent_id' only exists in child + stmt = select(parent).join(child).filter_by(parent_id=5) + + self.assert_compile( + stmt, + "SELECT parent.id, parent.data FROM parent " + "JOIN child ON parent.id = child.parent_id " + "WHERE child.parent_id = :parent_id_1", + checkparams={"parent_id_1": 5}, + ) + + def test_filter_by_column_not_in_any_entity(self): + """Test error when attribute not found in any FROM entity""" + stmt = select(parent).join(child) + + with expect_raises_message( + exc.InvalidRequestError, + 'None of the FROM clause entities have a property "nonexistent". ' + r"Searched entities: (?:parent(?:, )?" + r"|child(?:, )?){2}", + ): + stmt.filter_by(nonexistent="foo") + + def test_filter_by_multiple_joins(self): + """Test filter_by() with multiple joins""" + # grandchild has unique 'child_id' column + stmt = ( + select(parent) + .join(child, parent.c.id == child.c.parent_id) + .join(grandchild, child.c.id == grandchild.c.child_id) + .filter_by(child_id=3) + ) + + self.assert_compile( + stmt, + "SELECT parent.id, parent.data FROM parent " + "JOIN child ON parent.id = child.parent_id " + "JOIN grandchild ON child.id = grandchild.child_id " + "WHERE grandchild.child_id = :child_id_1", + checkparams={"child_id_1": 3}, + ) + + def test_filter_by_explicit_from_with_join(self): + """Test filter_by with explicit FROM and joins""" + stmt = select(parent.c.id).select_from(parent).join(child) + + # Should be ambiguous since both have 'data' + with expect_raises_message( + exc.AmbiguousColumnError, + 'Attribute name "data" is ambiguous; it exists in multiple ' + r"FROM clause entities \((?:parent(?:, )?" + r"|child(?:, )?){2}\).", + ): + stmt.filter_by(data="child_data") + def test_select_tuple_outer(self): stmt = select(tuple_(table1.c.myid, table1.c.name)) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index e5991663dc..b0c520eb3f 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1,5 +1,6 @@ import itertools import random +import re from sqlalchemy import bindparam from sqlalchemy import cast @@ -1934,3 +1935,231 @@ class UpdateFromMultiTableUpdateDefaultsTest( def _assert_users(self, connection, users, expected): stmt = users.select().order_by(users.c.id) eq_(connection.execute(stmt).fetchall(), expected) + + +class _FilterByDMLSuite(fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = "default_enhanced" + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("department_id", ForeignKey("departments.id")), + Column("name", String(30), nullable=False), + ) + Table( + "addresses", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("user_id", None, ForeignKey("users.id")), + Column("name", String(30), nullable=False), + Column("email_address", String(50), nullable=False), + ) + Table( + "dingalings", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("address_id", None, ForeignKey("addresses.id")), + Column("data", String(30)), + ) + Table( + "departments", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(30)), + ) + + def test_filter_by_basic(self, one_table_statement): + """Test filter_by with a single table.""" + stmt = one_table_statement + + stmt = stmt.filter_by(name="somename") + self.assert_compile( + stmt, + re.compile(r"(?:UPDATE|DELETE) .* WHERE users\.name = :name_1"), + params={"name_1": "somename"}, + ) + + def test_filter_by_three_tables_ambiguous(self, three_table_statement): + """test the three or fewer table ambiguous message""" + stmt = three_table_statement + + with expect_raises_message( + exc.AmbiguousColumnError, + r'Attribute name "name" is ambiguous; it exists in multiple ' + r"FROM clause entities \((?:dingalings(?:, )?" + r"|users(?:, )?|addresses(?:, )?){3}\).", + ): + stmt.filter_by(name="ambiguous") + + def test_filter_by_four_tables_ambiguous(self, four_table_statement): + """test the ellipses version of the ambiguous message""" + stmt = four_table_statement + + with expect_raises_message( + exc.AmbiguousColumnError, + r'Attribute name "name" is ambiguous; it exists in multiple ' + r"FROM clause entities " + r"\((?:dingalings, |departments, |users, |addresses, ){3}\.\.\. " + r"\(4 total\)\)", + ): + stmt.filter_by(name="ambiguous") + + def test_filter_by_three_tables_notfound(self, three_table_statement): + """test the three or fewer table not found message""" + stmt = three_table_statement + + with expect_raises_message( + exc.InvalidRequestError, + r'None of the FROM clause entities have a property "unknown". ' + r"Searched entities: (?:dingalings(?:, )?" + r"|users(?:, )?|addresses(?:, )?){3}", + ): + stmt.filter_by(unknown="notfound") + + def test_filter_by_four_tables_notfound(self, four_table_statement): + """test the ellipses version of the not found message""" + stmt = four_table_statement + + with expect_raises_message( + exc.InvalidRequestError, + r'None of the FROM clause entities have a property "unknown". ' + r"Searched entities: " + r"(?:dingalings, |departments, |users, |addresses, ){3}\.\.\. " + r"\(4 total\)", + ): + stmt.filter_by(unknown="notfound") + + def test_filter_by_two_tables_secondary(self, two_table_statement): + """Test filter_by finds attribute in secondary table (addresses).""" + stmt = two_table_statement + + # Filter by 'email_address' which only exists in addresses table + stmt = stmt.filter_by(email_address="test@example.com") + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* addresses\.email_address = " + r":email_address_1" + ), + ) + + def test_filter_by_three_tables_primary(self, three_table_statement): + """Test filter_by finds attribute in primary table with three + tables.""" + stmt = three_table_statement + + # Filter by 'id' - ambiguous across all three tables + with expect_raises_message( + exc.AmbiguousColumnError, + 'Attribute name "id" is ambiguous', + ): + stmt.filter_by(id=5) + + def test_filter_by_three_tables_secondary(self, three_table_statement): + """Test filter_by finds attribute in secondary table (addresses).""" + stmt = three_table_statement + + # Filter by 'email_address' which only exists in addresses + stmt = stmt.filter_by(email_address="test@example.com") + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* addresses\.email_address = " + r":email_address_1" + ), + ) + + def test_filter_by_three_tables_tertiary(self, three_table_statement): + """Test filter_by finds attribute in third table (dingalings).""" + stmt = three_table_statement + + # Filter by 'data' which only exists in dingalings + stmt = stmt.filter_by(data="somedata") + self.assert_compile( + stmt, + re.compile(r"(?:UPDATE|DELETE) .* dingalings\.data = :data_1"), + ) + + def test_filter_by_three_tables_user_id(self, three_table_statement): + """Test filter_by finds user_id in addresses (unambiguous).""" + stmt = three_table_statement + + # Filter by 'user_id' which only exists in addresses + stmt = stmt.filter_by(user_id=7) + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* addresses\.user_id = :user_id_1" + ), + ) + + def test_filter_by_three_tables_address_id(self, three_table_statement): + """Test filter_by finds address_id in dingalings (unambiguous).""" + stmt = three_table_statement + + # Filter by 'address_id' which only exists in dingalings + stmt = stmt.filter_by(address_id=3) + self.assert_compile( + stmt, + re.compile( + r"(?:UPDATE|DELETE) .* dingalings\.address_id = :address_id_1" + ), + ) + + +class UpdateFilterByTest(_FilterByDMLSuite): + @testing.fixture + def one_table_statement(self): + users = self.tables.users + + return users.update().values(name="newname") + + @testing.fixture + def two_table_statement(self): + users = self.tables.users + addresses = self.tables.addresses + + return ( + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + ) + + @testing.fixture + def three_table_statement(self): + users = self.tables.users + addresses = self.tables.addresses + dingalings = self.tables.dingalings + + return ( + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.id == dingalings.c.address_id) + ) + + @testing.fixture + def four_table_statement(self): + users = self.tables.users + addresses = self.tables.addresses + dingalings = self.tables.dingalings + departments = self.tables.departments + + return ( + users.update() + .values(name="newname") + .where(users.c.id == addresses.c.user_id) + .where(addresses.c.id == dingalings.c.address_id) + .where(departments.c.id == users.c.department_id) + )