--- /dev/null
+.. change::
+ :tags: bug, orm
+ :tickets: 6060
+
+ Fixed bug where ORM queries using a correlated subquery in conjunction with
+ :func:`_orm.column_property` would fail to correlate correctly to an
+ enclosing subquery or to a CTE when :meth:`_sql.Select.correlate_except`
+ were used in the property to control correlation, in cases where the
+ subquery contained the same selectables as ones within the correlated
+ subquery that were intended to not be correlated.
equivalents=mapper._equivalent_columns,
adapt_on_names=adapt_on_names,
anonymize_labels=True,
+ # make sure the adapter doesn't try to grab other tables that
+ # are not even the thing we are mapping, such as embedded
+ # selectables in subqueries or CTEs. See issue #6060
+ adapt_from_selectables=[
+ m.selectable for m in self.with_polymorphic_mappers
+ ],
)
+
if inspected.is_aliased_class:
self._adapter = inspected._adapter.wrap(self._adapter)
exclude_fn=None,
adapt_on_names=False,
anonymize_labels=False,
+ adapt_from_selectables=None,
):
self.__traverse_options__ = {
"stop_on": [selectable],
self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
+ self.adapt_from_selectables = adapt_from_selectables
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
if isinstance(col, FromClause) and not isinstance(
col, functions.FunctionElement
):
+ if self.adapt_from_selectables:
+ for adp in self.adapt_from_selectables:
+ if adp.is_derived_from(col):
+ break
+ else:
+ return None
+
if self.selectable.is_derived_from(col):
return self.selectable
elif isinstance(col, Alias) and isinstance(
if "adapt_column" in col._annotations:
col = col._annotations["adapt_column"]
+ if self.adapt_from_selectables and col not in self.equivalents:
+ for adp in self.adapt_from_selectables:
+ if adp.c.corresponding_column(col, False) is not None:
+ break
+ else:
+ return None
+
if self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
adapt_on_names=False,
allow_label_resolve=True,
anonymize_labels=False,
+ adapt_from_selectables=None,
):
ClauseAdapter.__init__(
self,
exclude_fn=exclude_fn,
adapt_on_names=adapt_on_names,
anonymize_labels=anonymize_labels,
+ adapt_from_selectables=adapt_from_selectables,
)
self.columns = util.WeakPopulateDict(self._locate_col)
# unfortunately there's a lot of cycles with an aliased()
# for now, however calling upon clause_element does not seem
# to make it worse which is what this was looking to test
- @assert_cycles(68)
+ @assert_cycles(69)
def go():
a1 = aliased(Foo)
a1.user_name.__clause_element__()
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
+from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import true
from sqlalchemy import union
from sqlalchemy import util
from sqlalchemy.engine import default
+from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import aliased
from sqlalchemy.orm import backref
from sqlalchemy.orm import clear_mappers
q = sess.query(User)
def go():
- ulist_alias = aliased(User, alias=query.alias("ulist"))
+ ulist = query.alias("ulist")
+ ulist_alias = aliased(User, alias=ulist)
result = (
- q.options(contains_eager("addresses"))
+ q.options(contains_eager("addresses", alias=ulist))
.select_entity_from(ulist_alias)
.all()
)
# all three columns are loaded independently without
# overlap, no additional SQL to load all attributes
self.assert_sql_count(testing.db, go, 0)
+
+
+class CorrelateORMTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ @testing.fixture
+ def mapping(self):
+ Base = declarative_base()
+
+ def go(include_property, correlate_style, include_from):
+ class Address(Base):
+ __tablename__ = "addresses"
+
+ id = Column(Integer, primary_key=True)
+ user_id = Column(
+ Integer, ForeignKey("users.id"), nullable=False
+ )
+ city = Column(Text)
+
+ class User(Base):
+ __tablename__ = "users"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(Text)
+
+ stmt = select(func.count(Address.id)).where(
+ Address.user_id == User.id
+ )
+ if include_from:
+ stmt = stmt.select_from(Address)
+
+ if include_property:
+ if correlate_style == "correlate":
+ User.total_addresses = column_property(
+ stmt.correlate(User).scalar_subquery()
+ )
+ elif correlate_style == "correlate_except":
+ User.total_addresses = column_property(
+ stmt.correlate_except(Address).scalar_subquery()
+ )
+ elif correlate_style is None:
+ User.total_addresses = column_property(
+ stmt.scalar_subquery()
+ )
+ total_addresses = None
+ else:
+
+ def total_addresses(cls):
+ stmt = select(func.count(Address.id)).where(
+ Address.user_id == cls.id
+ )
+
+ if correlate_style == "correlate":
+ stmt = stmt.correlate(cls)
+ elif correlate_style == "correlate_except":
+ stmt = stmt.correlate_except(Address)
+
+ stmt = stmt.scalar_subquery()
+
+ return stmt
+
+ return User, Address, total_addresses
+
+ yield go
+ Base.registry.dispose()
+
+ def _combinations(fn):
+
+ return testing.combinations(
+ (True,), (False,), argnames="include_property"
+ )(
+ testing.combinations(
+ ("correlate",),
+ ("correlate_except",),
+ (None,),
+ argnames="correlate_style",
+ )(
+ testing.combinations(
+ (True,), (False), argnames="include_from"
+ )(fn)
+ )
+ )
+
+ @_combinations
+ def test_correlate_to_cte_legacy(
+ self, mapping, include_property, correlate_style, include_from
+ ):
+ User, Address, total_addresses = mapping(
+ include_property, correlate_style, include_from
+ )
+ session = fixture_session()
+
+ filtered_users = (
+ session.query(User.id, User.name)
+ .join(Address)
+ .filter(Address.city == "somewhere")
+ .cte("filtered_users")
+ )
+
+ filtered_users_alias = aliased(User, filtered_users)
+
+ paginated_users = (
+ session.query(filtered_users_alias.id, filtered_users_alias.name)
+ .order_by(func.lower(filtered_users_alias.name).asc())
+ .limit(25)
+ .cte("paginated_users")
+ )
+
+ paginated_users_alias = aliased(User, paginated_users)
+
+ if total_addresses:
+ q = session.query(
+ paginated_users_alias, total_addresses(paginated_users_alias)
+ )
+ else:
+ q = session.query(paginated_users_alias)
+ self.assert_compile(
+ q,
+ "WITH filtered_users AS "
+ "(SELECT users.id AS id, users.name AS name "
+ "FROM users JOIN addresses ON users.id = addresses.user_id "
+ "WHERE addresses.city = :city_1), "
+ "paginated_users AS (SELECT filtered_users.id AS id, "
+ "filtered_users.name AS name FROM filtered_users "
+ "ORDER BY lower(filtered_users.name) ASC LIMIT :param_1) "
+ "SELECT "
+ "paginated_users.id AS paginated_users_id, "
+ "paginated_users.name AS paginated_users_name, "
+ "(SELECT count(addresses.id) AS count_1 FROM addresses "
+ "WHERE addresses.user_id = paginated_users.id) AS anon_1 "
+ "FROM paginated_users",
+ )
column_adapter = sql_util.ColumnAdapter(stmt2)
is_(column_adapter.columns[expr], stmt2.selected_columns[3])
- def test_correlate_except_on_clone(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_correlate_except_on_clone(self, use_adapt_from):
# test [ticket:4537]'s issue
t1alias = t1.alias("t1alias")
j = t1.join(t1alias, t1.c.col1 == t1alias.c.col2)
- vis = sql_util.ClauseAdapter(j)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(j, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(j)
# "control" subquery - uses correlate which has worked w/ adaption
# for a long time
"JOIN table2 ON table1.col1 = table2.col1",
)
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_correlate_except_with_mixed_tables(self, use_adapt_from):
+ # test [ticket:6060]'s issue
+
+ stmt = select(
+ t1.c.col1,
+ select(func.count(t2.c.col1))
+ .where(t2.c.col1 == t1.c.col1)
+ .correlate_except(t2)
+ .scalar_subquery(),
+ )
+ self.assert_compile(
+ stmt,
+ "SELECT table1.col1, "
+ "(SELECT count(table2.col1) AS count_1 FROM table2 "
+ "WHERE table2.col1 = table1.col1) AS anon_1 "
+ "FROM table1",
+ )
+
+ subq = (
+ select(t1)
+ .join(t2, t1.c.col1 == t2.c.col1)
+ .where(t2.c.col2 == "x")
+ .subquery()
+ )
+
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(subq, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(subq)
+
+ if use_adapt_from:
+ self.assert_compile(
+ vis.traverse(stmt),
+ "SELECT anon_1.col1, "
+ "(SELECT count(table2.col1) AS count_1 FROM table2 WHERE "
+ "table2.col1 = anon_1.col1) AS anon_2 "
+ "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
+ "table1.col3 AS col3 FROM table1 JOIN table2 ON table1.col1 = "
+ "table2.col1 WHERE table2.col2 = :col2_1) AS anon_1",
+ )
+ else:
+ # here's the buggy version. table2 gets yanked out of the
+ # correlated subquery also. AliasedClass now uses
+ # adapt_from_selectables in all cases
+ self.assert_compile(
+ vis.traverse(stmt),
+ "SELECT anon_1.col1, "
+ "(SELECT count(table2.col1) AS count_1 FROM table2, "
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
+ "table1.col3 AS col3 FROM table1 JOIN table2 ON "
+ "table1.col1 = table2.col1 WHERE table2.col2 = :col2_1) AS "
+ "anon_1 WHERE table2.col1 = anon_1.col1) AS anon_2 "
+ "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
+ "table1.col3 AS col3 FROM table1 JOIN table2 "
+ "ON table1.col1 = table2.col1 "
+ "WHERE table2.col2 = :col2_1) AS anon_1",
+ )
+
@testing.fails_on_everything_except()
def test_joins_dont_adapt(self):
# adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't
"addresses.user_id",
)
- def test_table_to_alias_1(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_table_to_alias_1(self, use_adapt_from):
t1alias = t1.alias("t1alias")
- vis = sql_util.ClauseAdapter(t1alias)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(t1alias)
ff = vis.traverse(func.count(t1.c.col1).label("foo"))
assert list(_from_objects(ff)) == [t1alias]
- def test_table_to_alias_2(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_table_to_alias_2(self, use_adapt_from):
t1alias = t1.alias("t1alias")
- vis = sql_util.ClauseAdapter(t1alias)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(t1alias)
self.assert_compile(
vis.traverse(select(literal_column("*")).select_from(t1)),
"SELECT * FROM table1 AS t1alias",
)
- def test_table_to_alias_3(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_table_to_alias_3(self, use_adapt_from):
t1alias = t1.alias("t1alias")
- vis = sql_util.ClauseAdapter(t1alias)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(t1alias)
self.assert_compile(
vis.traverse(
select(literal_column("*")).where(t1.c.col1 == t2.c.col2)
"WHERE t1alias.col1 = table2.col2",
)
- def test_table_to_alias_4(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_table_to_alias_4(self, use_adapt_from):
t1alias = t1.alias("t1alias")
- vis = sql_util.ClauseAdapter(t1alias)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(t1alias)
self.assert_compile(
vis.traverse(
select(literal_column("*"))
"WHERE t1alias.col1 = table2.col2",
)
- def test_table_to_alias_5(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_table_to_alias_5(self, use_adapt_from):
t1alias = t1.alias("t1alias")
- vis = sql_util.ClauseAdapter(t1alias)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(t1alias)
self.assert_compile(
select(t1alias, t2).where(
t1alias.c.col1
"(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)",
)
- def test_table_to_alias_6(self):
+ @testing.combinations((True,), (False,), argnames="use_adapt_from")
+ def test_table_to_alias_6(self, use_adapt_from):
t1alias = t1.alias("t1alias")
- vis = sql_util.ClauseAdapter(t1alias)
+ if use_adapt_from:
+ vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+ else:
+ vis = sql_util.ClauseAdapter(t1alias)
self.assert_compile(
select(t1alias, t2).where(
t1alias.c.col1