from ..sql._typing import _HasClauseElement
from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
+from ..sql.util import _deep_annotate
from ..sql.util import _deep_deannotate
from ..sql.util import _shallow_annotate
from ..sql.util import adapt_criterion_to_null
from ..sql._typing import _EquivalentColumnMap
from ..sql._typing import _InfoType
from ..sql.annotation import _AnnotationDict
+ from ..sql.annotation import SupportsAnnotations
from ..sql.elements import BinaryExpression
from ..sql.elements import BindParameter
from ..sql.elements import ClauseElement
primaryjoin = primaryjoin & single_crit
if extra_criteria:
+
+ def mark_unrelated_columns_as_ok_to_adapt(
+ elem: SupportsAnnotations, annotations: _AnnotationDict
+ ) -> SupportsAnnotations:
+ """note unrelated columns in the "extra criteria" as OK
+ to adapt, even though they are not part of our "local"
+ or "remote" side.
+
+ see #9779 for this case
+
+ """
+
+ parentmapper_for_element = elem._annotations.get(
+ "parentmapper", None
+ )
+ if (
+ parentmapper_for_element is not self.prop.parent
+ and parentmapper_for_element is not self.prop.mapper
+ ):
+ return elem._annotate(annotations)
+ else:
+ return elem
+
+ extra_criteria = tuple(
+ _deep_annotate(
+ elem,
+ {"ok_to_adapt_in_join_condition": True},
+ annotate_callable=mark_unrelated_columns_as_ok_to_adapt,
+ )
+ for elem in extra_criteria
+ )
+
if secondaryjoin is not None:
secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
else:
self.name = name
def __call__(self, c: ClauseElement) -> bool:
- return self.name in c._annotations
+ return (
+ self.name in c._annotations
+ or "ok_to_adapt_in_join_condition" in c._annotations
+ )
class Relationship( # type: ignore
element: _SA,
annotations: _AnnotationDict,
exclude: Optional[Sequence[SupportsAnnotations]] = None,
+ *,
detect_subquery_cols: bool = False,
ind_cols_on_fromclause: bool = False,
+ annotate_callable: Optional[
+ Callable[[SupportsAnnotations, _AnnotationDict], SupportsAnnotations]
+ ] = None,
) -> _SA:
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
newelem = elem._clone(clone=clone, **kw)
elif annotations != elem._annotations:
if detect_subquery_cols and elem._is_immutable:
- newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+ to_annotate = elem._clone(clone=clone, **kw)
else:
- newelem = elem._annotate(annotations)
+ to_annotate = elem
+ if annotate_callable:
+ newelem = annotate_callable(to_annotate, annotations)
+ else:
+ newelem = to_annotate._annotate(annotations)
else:
newelem = elem
from typing import Optional
+from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
assert False
self._run_load(opt)
+
+
+class AdaptExistsSubqTest(fixtures.DeclarativeMappedTest):
+ """test for #9777"""
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Discriminator(Base):
+ __tablename__ = "discriminator"
+ id = Column(Integer, primary_key=True, autoincrement=False)
+ value = Column(String(50))
+
+ class Entity(Base):
+ __tablename__ = "entity"
+ __mapper_args__ = {"polymorphic_on": "type"}
+
+ id = Column(Integer, primary_key=True, autoincrement=False)
+ type = Column(String(50))
+
+ discriminator_id = Column(
+ ForeignKey("discriminator.id"), nullable=False
+ )
+ discriminator = relationship(
+ "Discriminator", foreign_keys=discriminator_id
+ )
+
+ class Parent(Entity):
+ __tablename__ = "parent"
+ __mapper_args__ = {"polymorphic_identity": "parent"}
+
+ id = Column(Integer, ForeignKey("entity.id"), primary_key=True)
+ some_data = Column(String(30))
+
+ class Child(Entity):
+ __tablename__ = "child"
+ __mapper_args__ = {"polymorphic_identity": "child"}
+
+ id = Column(Integer, ForeignKey("entity.id"), primary_key=True)
+
+ some_data = Column(String(30))
+ parent_id = Column(ForeignKey("parent.id"), nullable=False)
+ parent = relationship(
+ "Parent",
+ foreign_keys=parent_id,
+ backref="children",
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ Parent, Child, Discriminator = cls.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ with Session(connection) as sess:
+ discriminator_zero = Discriminator(id=1, value="zero")
+ discriminator_one = Discriminator(id=2, value="one")
+ discriminator_two = Discriminator(id=3, value="two")
+
+ parent = Parent(id=1, discriminator=discriminator_zero)
+ child_1 = Child(
+ id=2,
+ discriminator=discriminator_one,
+ parent=parent,
+ some_data="c1data",
+ )
+ child_2 = Child(
+ id=3,
+ discriminator=discriminator_two,
+ parent=parent,
+ some_data="c2data",
+ )
+ sess.add_all([parent, child_1, child_2])
+ sess.commit()
+
+ def test_explicit_aliasing(self):
+ Parent, Child, Discriminator = self.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ parent_id = 1
+ discriminator_one_id = 2
+
+ session = fixture_session()
+ c_alias = aliased(Child, flat=True)
+ retrieved = (
+ session.query(Parent)
+ .filter_by(id=parent_id)
+ .outerjoin(
+ Parent.children.of_type(c_alias).and_(
+ c_alias.discriminator.has(
+ and_(
+ Discriminator.id == discriminator_one_id,
+ c_alias.some_data == "c1data",
+ )
+ )
+ )
+ )
+ .options(contains_eager(Parent.children.of_type(c_alias)))
+ .populate_existing()
+ .one()
+ )
+ eq_(len(retrieved.children), 1)
+
+ def test_implicit_aliasing(self):
+ Parent, Child, Discriminator = self.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ parent_id = 1
+ discriminator_one_id = 2
+
+ session = fixture_session()
+ q = (
+ session.query(Parent)
+ .filter_by(id=parent_id)
+ .outerjoin(
+ Parent.children.and_(
+ Child.discriminator.has(
+ and_(
+ Discriminator.id == discriminator_one_id,
+ Child.some_data == "c1data",
+ )
+ )
+ )
+ )
+ .options(contains_eager(Parent.children))
+ .populate_existing()
+ )
+
+ with expect_warnings("An alias is being generated automatically"):
+ retrieved = q.one()
+
+ eq_(len(retrieved.children), 1)
+
+ @testing.combinations(joinedload, selectinload, argnames="loader")
+ def test_eager_loaders(self, loader):
+ Parent, Child, Discriminator = self.classes(
+ "Parent", "Child", "Discriminator"
+ )
+
+ parent_id = 1
+ discriminator_one_id = 2
+
+ session = fixture_session()
+ retrieved = (
+ session.query(Parent)
+ .filter_by(id=parent_id)
+ .options(
+ loader(
+ Parent.children.and_(
+ Child.discriminator.has(
+ and_(
+ Discriminator.id == discriminator_one_id,
+ Child.some_data == "c1data",
+ )
+ )
+ )
+ )
+ )
+ .populate_existing()
+ .one()
+ )
+
+ eq_(len(retrieved.children), 1)