--- /dev/null
+.. change::
+ :tags: bug, orm
+ :tickets: 9805
+
+ As more projects are using new-style "2.0" ORM querying, it's becoming
+ apparent that the conditional nature of "autoflush", being based on whether
+ or not the given statement refers to ORM entities, is becoming more of a
+ key behavior. Up until now, the "ORM" flag for a statement has been loosely
+ based around whether or not the statement returns rows that correspond to
+ ORM entities or columns; the original purpose of the "ORM" flag was to
+ enable ORM-entity fetching rules which apply post-processing to Core result
+ sets as well as ORM loader strategies to the statement. For statements
+ that don't build on rows that contain ORM entities, the "ORM" flag was
+ considered to be mostly unnecessary.
+
+ It still may be the case that "autoflush" would be better taking effect for
+ *all* usage of :meth:`_orm.Session.execute` and related methods, even for
+ purely Core SQL constructs. However, this still could impact legacy cases
+ where this is not expected and may be more of a 2.1 thing. For now however,
+ the rules for the "ORM-flag" have been opened up so that a statement that
+ includes ORM entities or attributes anywhere within, including in the WHERE
+ / ORDER BY / GROUP BY clause alone, within scalar subqueries, etc. will
+ enable this flag. This will cause "autoflush" to occur for such statements
+ and also be visible via the :attr:`_orm.ORMExecuteState.is_orm_statement`
+ event-level attribute.
+
+
update_options += {"_subject_mapper": plugin_subject.mapper}
- if not isinstance(params, list):
+ if "parententity" not in statement.table._annotations:
+ update_options += {"_dml_strategy": "core_only"}
+ elif not isinstance(params, list):
if update_options._dml_strategy == "auto":
update_options += {"_dml_strategy": "orm"}
elif update_options._dml_strategy == "bulk":
if toplevel and dml_strategy == "bulk":
self._setup_for_bulk_update(statement, compiler)
+ elif (
+ dml_strategy == "core_only"
+ or dml_strategy == "unspecified"
+ and "parententity" not in statement.table._annotations
+ ):
+ UpdateDMLState.__init__(self, statement, compiler, **kw)
elif not toplevel or dml_strategy in ("orm", "unspecified"):
self._setup_for_orm_update(statement, compiler)
"_sa_orm_update_options", cls.default_update_options
)
- if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+ if update_options._dml_strategy not in (
+ "orm",
+ "auto",
+ "bulk",
+ "core_only",
+ ):
raise sa_exc.ArgumentError(
"Valid strategies for ORM UPDATE strategy "
- "are 'orm', 'auto', 'bulk'"
+ "are 'orm', 'auto', 'bulk', 'core_only'"
)
result: _result.Result[Any]
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ dml_strategy = statement._annotations.get(
+ "dml_strategy", "unspecified"
+ )
+
+ if (
+ dml_strategy == "core_only"
+ or dml_strategy == "unspecified"
+ and "parententity" not in statement.table._annotations
+ ):
+ DeleteDMLState.__init__(self, statement, compiler, **kw)
+ return self
+
toplevel = not compiler.stack
orm_level_statement = statement
"session.connection().execute(stmt, parameters)"
)
- if update_options._dml_strategy not in (
- "orm",
- "auto",
- ):
+ if update_options._dml_strategy not in ("orm", "auto", "core_only"):
raise sa_exc.ArgumentError(
- "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+ "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
+ "'core_only'"
)
return super().orm_execute_statement(
from ..sql.dml import UpdateBase
from ..sql.elements import GroupedElement
from ..sql.elements import TextClause
+from ..sql.selectable import CompoundSelectState
from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from ..sql.selectable import LABEL_STYLE_NONE
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
raise NotImplementedError()
+class AutoflushOnlyORMCompileState(AbstractORMCompileState):
+ """ORM compile state that is a passthrough, except for autoflush."""
+
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_pre_event,
+ ):
+
+ # consume result-level load_options. These may have been set up
+ # in an ORMExecuteState hook
+ (
+ load_options,
+ execution_options,
+ ) = QueryContext.default_load_options.from_execution_options(
+ "_sa_orm_load_options",
+ {
+ "autoflush",
+ },
+ execution_options,
+ statement._execution_options,
+ )
+
+ if not is_pre_event and load_options._autoflush:
+ session._autoflush()
+
+ return statement, execution_options
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+ return result
+
+
class ORMCompileState(AbstractORMCompileState):
class default_compile_options(CacheableOptions):
_cache_key_traversal = [
return self.element._inline if is_insert_update(self.element) else None
+@sql.base.CompileState.plugin_for("orm", "compound_select")
+class CompoundSelectCompileState(
+ AutoflushOnlyORMCompileState, CompoundSelectState
+):
+ pass
+
+
@sql.base.CompileState.plugin_for("orm", "select")
class ORMSelectCompileState(ORMCompileState, SelectState):
_already_joined_edges = ()
if typing.TYPE_CHECKING:
assert isinstance(resolved, (SQLCoreOperations, ClauseElement))
- if (
- not apply_propagate_attrs._propagate_attrs
- and resolved._propagate_attrs
+ if not apply_propagate_attrs._propagate_attrs and getattr(
+ resolved, "_propagate_attrs", None
):
apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs
for criterion in whereclause:
where_criteria: ColumnElement[Any] = coercions.expect(
- roles.WhereHavingRole, criterion
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
)
self._where_criteria += (where_criteria,)
return self
# nulltype assignment issue
self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore
+ self._propagate_attrs = element._propagate_attrs
def _with_binary_element_type(self, type_):
return self.__class__(self.element._with_binary_element_type(type_))
self._order_by_clauses = ()
elif __first is not _NoArg.NO_ARG:
self._order_by_clauses += tuple(
- coercions.expect(roles.OrderByRole, clause)
+ coercions.expect(
+ roles.OrderByRole, clause, apply_propagate_attrs=self
+ )
for clause in (__first,) + clauses
)
return self
self._group_by_clauses = ()
elif __first is not _NoArg.NO_ARG:
self._group_by_clauses += tuple(
- coercions.expect(roles.GroupByRole, clause)
+ coercions.expect(
+ roles.GroupByRole, clause, apply_propagate_attrs=self
+ )
for clause in (__first,) + clauses
)
return self
):
self.keyword = keyword
self.selects = [
- coercions.expect(roles.CompoundElementRole, s).self_group(
- against=self
- )
+ coercions.expect(
+ roles.CompoundElementRole, s, apply_propagate_attrs=self
+ ).self_group(against=self)
for s in selects
]
for criterion in whereclause:
where_criteria: ColumnElement[Any] = coercions.expect(
- roles.WhereHavingRole, criterion
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
)
self._where_criteria += (where_criteria,)
return self
for criterion in having:
having_criteria = coercions.expect(
- roles.WhereHavingRole, criterion
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
)
self._having_criteria += (having_criteria,)
return self
if expr:
self._distinct = True
self._distinct_on = self._distinct_on + tuple(
- coercions.expect(roles.ByOfRole, e) for e in expr
+ coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self)
+ for e in expr
)
else:
self._distinct = True
def __init__(self, element: SelectBase) -> None:
self.element = element
self.type = element._scalar_type()
+ self._propagate_attrs = element._propagate_attrs
def __getattr__(self, attr: str) -> Any:
return getattr(self.element, attr)
),
(
lambda User: select(1).where(User.name == "ed"),
- # no mapper for this one because the plugin is not "orm"
- lambda User: {"clause": mock.ANY},
+ # changed by #9805
+ lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
"e1",
),
(
from sqlalchemy import Column
from sqlalchemy import delete
from sqlalchemy import exc
+from sqlalchemy import exists
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import text
+from sqlalchemy import true
from sqlalchemy import union
from sqlalchemy import update
from sqlalchemy import util
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import query_expression
from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
from sqlalchemy.orm import undefer
from sqlalchemy.orm import with_expression
from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import mock
from sqlalchemy.testing import Variation
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.util import resolve_lambda
)
+class PropagateAttrsTest(QueryTest):
+ def propagate_cases():
+ return testing.combinations(
+ (lambda: select(1), False),
+ (lambda User: select(func.count(User.id)), True),
+ (
+ lambda User: select(1).select_from(select(User).subquery()),
+ True,
+ ),
+ (
+ lambda User: select(
+ select(User.id).where(User.id == 5).scalar_subquery()
+ ),
+ True,
+ ),
+ (
+ lambda User: select(
+ select(User.id).where(User.id == 5).label("x")
+ ),
+ True,
+ ),
+ (lambda User: select(1).select_from(User), True),
+ (lambda User: select(1).where(exists(User.id)), True),
+ (lambda User: select(1).where(~exists(User.id)), True),
+ (
+ # changed as part of #9805
+ lambda User: select(1).where(User.id == 1),
+ True,
+ ),
+ (
+ # changed as part of #9805
+ lambda User, user_table: select(func.count(1))
+ .select_from(user_table)
+ .group_by(user_table.c.id)
+ .having(User.id == 1),
+ True,
+ ),
+ (
+ # changed as part of #9805
+ lambda User, user_table: select(1)
+ .select_from(user_table)
+ .order_by(User.id),
+ True,
+ ),
+ (
+ # changed as part of #9805
+ lambda User, user_table: select(1)
+ .select_from(user_table)
+ .group_by(User.id),
+ True,
+ ),
+ (
+ lambda User, user_table: select(user_table).join(
+ aliased(User), true()
+ ),
+ True,
+ ),
+ (
+ # changed as part of #9805
+ lambda User, user_table: select(1)
+ .distinct(User.id)
+ .select_from(user_table),
+ True,
+ testing.requires.supports_distinct_on,
+ ),
+ (lambda user_table: select(user_table), False),
+ (lambda User: select(User), True),
+ (lambda User: union(select(User), select(User)), True),
+ (
+ lambda User: select(1).select_from(
+ union(select(User), select(User)).subquery()
+ ),
+ True,
+ ),
+ (lambda User: select(User.id), True),
+ # these are meaningless, correlate by itself has no effect
+ (lambda User: select(1).correlate(User), False),
+ (lambda User: select(1).correlate_except(User), False),
+ (lambda User: delete(User).where(User.id > 20), True),
+ (
+ lambda User, user_table: delete(user_table).where(
+ User.id > 20
+ ),
+ True,
+ ),
+ (lambda User: update(User).values(name="x"), True),
+ (
+ lambda User, user_table: update(user_table)
+ .values(name="x")
+ .where(User.id > 20),
+ True,
+ ),
+ (lambda User: insert(User).values(name="x"), True),
+ )
+
+ @propagate_cases()
+ def test_propagate_attr_yesno(self, test_case, expected):
+ User = self.classes.User
+ user_table = self.tables.users
+
+ stmt = resolve_lambda(test_case, User=User, user_table=user_table)
+
+ eq_(bool(stmt._propagate_attrs), expected)
+
+ @propagate_cases()
+ def test_autoflushes(self, test_case, expected):
+ User = self.classes.User
+ user_table = self.tables.users
+
+ stmt = resolve_lambda(test_case, User=User, user_table=user_table)
+
+ with Session(testing.db) as s:
+
+ with mock.patch.object(s, "_autoflush", wrap=True) as before_flush:
+ r = s.execute(stmt)
+ r.close()
+
+ if expected:
+ eq_(before_flush.mock_calls, [mock.call()])
+ else:
+ eq_(before_flush.mock_calls, [])
+
+
class DMLTest(QueryTest, AssertsCompiledSQL):
__dialect__ = "default"
@testing.combinations(
(lambda: select(1), True),
- (lambda User: select(User).union(select(User)), True),
+ (
+ lambda user_table: select(user_table).union(select(user_table)),
+ True,
+ ),
(lambda: text("select * from users"), False),
)
def test_non_orm_statements(self, stmt, is_select):
canary = self._flag_fixture(sess)
- User, Address = self.classes("User", "Address")
- stmt = testing.resolve_lambda(stmt, User=User)
+ user_table = self.tables.users
+ stmt = testing.resolve_lambda(stmt, user_table=user_table)
sess.execute(stmt).all()
eq_(