From: Mike Bayer Date: Wed, 3 Jun 2020 13:50:04 +0000 (-0400) Subject: Inline a few ORM arguments, others X-Git-Tag: rel_1_4_0b1~282 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7f0cb933f2b1979a8d781855618b7fd3bf280037;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Inline a few ORM arguments, others small changes Change-Id: Id89a0651196c431d0aaf6935f5a4e7b12dd70c6c --- diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index ba30d203bd..bd4074ea11 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -191,15 +191,9 @@ class ORMCompileState(CompileState): def orm_pre_session_exec( cls, session, statement, execution_options, bind_arguments ): - if execution_options: - # TODO: will have to provide public API to set some load - # options and also extract them from that API here, likely - # execution options - load_options = execution_options.get( - "_sa_orm_load_options", QueryContext.default_load_options - ) - else: - load_options = QueryContext.default_load_options + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) bind_arguments["clause"] = statement @@ -223,7 +217,9 @@ class ORMCompileState(CompileState): session._autoflush() @classmethod - def orm_setup_cursor_result(cls, session, bind_arguments, result): + def orm_setup_cursor_result( + cls, session, statement, execution_options, bind_arguments, result + ): execution_context = result.context compile_state = execution_context.compiled.compile_state @@ -231,13 +227,9 @@ class ORMCompileState(CompileState): # were passed to session.execute: # session.execute(legacy_select([User.id, User.name])) # see test_query->test_legacy_tuple_old_select - if not execution_context.compiled.statement._is_future: + if not statement._is_future: return result - execution_options = execution_context.execution_options - - # we are getting these right above in orm_pre_session_exec(), - # then getting them again right here. load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 25e2243487..ee42419a26 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1297,17 +1297,16 @@ class Session(_SessionClassMethods): ) def _connection_for_bind(self, engine, execution_options=None, **kw): - self._autobegin() - - if self._transaction is not None: + if self._transaction is not None or self._autobegin(): return self._transaction._connection_for_bind( engine, execution_options ) - else: - conn = engine.connect(**kw) - if execution_options: - conn = conn.execution_options(**execution_options) - return conn + + assert self._transaction is None + conn = engine.connect(**kw) + if execution_options: + conn = conn.execution_options(**execution_options) + return conn def execute( self, @@ -1460,6 +1459,23 @@ class Session(_SessionClassMethods): compile_state_cls.orm_pre_session_exec( self, statement, execution_options, bind_arguments ) + + if self.dispatch.do_orm_execute: + skip_events = bind_arguments.pop("_sa_skip_events", False) + + if not skip_events: + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + ) + for fn in self.dispatch.do_orm_execute: + result = fn(orm_exec_state) + if result: + return result + else: compile_state_cls = None bind_arguments.setdefault("clause", statement) @@ -1468,18 +1484,6 @@ class Session(_SessionClassMethods): execution_options, {"future_result": True} ) - if self.dispatch.do_orm_execute: - skip_events = bind_arguments.pop("_sa_skip_events", False) - - if not skip_events: - orm_exec_state = ORMExecuteState( - self, statement, params, execution_options, bind_arguments - ) - for fn in self.dispatch.do_orm_execute: - result = fn(orm_exec_state) - if result: - return result - bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind, close_with_result=True) @@ -1487,7 +1491,7 @@ class Session(_SessionClassMethods): if compile_state_cls: result = compile_state_cls.orm_setup_cursor_result( - self, bind_arguments, result + self, statement, execution_options, bind_arguments, result ) return result diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4bd19e04b5..f4160b5520 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2829,10 +2829,12 @@ class SQLCompiler(Compiled): if self.linting & COLLECT_CARTESIAN_PRODUCTS: from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING if toplevel: self.from_linter = from_linter else: from_linter = None + warn_linting = False if froms: text += " \nFROM " @@ -2872,10 +2874,7 @@ class SQLCompiler(Compiled): if t: text += " \nWHERE " + t - if ( - self.linting & COLLECT_CARTESIAN_PRODUCTS - and self.linting & WARN_LINTING - ): + if warn_linting: from_linter.warn() if select._group_by_clauses: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index fa2888a23e..986bf134c7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2189,6 +2189,10 @@ class BooleanClauseList(ClauseList, ColumnElement): has_continue_on = None convert_clauses = [] + + against = operators._asbool + lcc = 0 + for clause in clauses: if clause is continue_on: # instance of continue_on, like and_(x, y, True, z), store it @@ -2199,24 +2203,26 @@ class BooleanClauseList(ClauseList, ColumnElement): # instance of skip_on, e.g. and_(x, y, False, z), cancels # the rest out convert_clauses = [clause] + lcc = 1 break else: + if not lcc: + lcc = 1 + else: + against = operator + # techincally this would be len(convert_clauses) + 1 + # however this only needs to indicate "greater than one" + lcc = 2 convert_clauses.append(clause) if not convert_clauses and has_continue_on is not None: convert_clauses = [has_continue_on] + lcc = 1 - lcc = len(convert_clauses) - - if lcc > 1: - against = operator - else: - against = operators._asbool return lcc, [c.self_group(against=against) for c in convert_clauses] @classmethod def _construct(cls, operator, continue_on, skip_on, *clauses, **kw): - lcc, convert_clauses = cls._process_clauses_for_boolean( operator, continue_on, diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 5dbbc2f5c1..1b31c96e9c 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -870,7 +870,9 @@ class JoinedEagerLoadTest(fixtures.MappedTest): ) r.context.compiled.compile_state = compile_state - obj = ORMCompileState.orm_setup_cursor_result(sess, {}, r) + obj = ORMCompileState.orm_setup_cursor_result( + sess, compile_state.statement, exec_opts, {}, r + ) list(obj) sess.close() diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 5cc6140a51..a7cc82ddfa 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -3551,6 +3551,23 @@ class CountTest(QueryTest): eq_(s.query(User).filter(users.c.name.endswith("ed")).count(), 2) + def test_basic_future(self): + User = self.classes.User + + s = create_session() + + eq_( + s.execute(future_select(func.count()).select_from(User)).scalar(), + 4, + ) + + eq_( + s.execute( + future_select(func.count()).filter(User.name.endswith("ed")) + ).scalar(), + 2, + ) + def test_count_char(self): User = self.classes.User s = create_session() @@ -3579,6 +3596,21 @@ class CountTest(QueryTest): q = s.query(User, Address).join(User.addresses) eq_(q.count(), 5) + def test_multiple_entity_future(self): + User, Address = self.classes.User, self.classes.Address + + s = create_session() + + stmt = future_select(User, Address).join(Address, true()) + + stmt = future_select(func.count()).select_from(stmt.subquery()) + eq_(s.scalar(stmt), 20) # cartesian product + + stmt = future_select(User, Address).join(Address) + + stmt = future_select(func.count()).select_from(stmt.subquery()) + eq_(s.scalar(stmt), 5) + def test_nested(self): User, Address = self.classes.User, self.classes.Address @@ -3592,6 +3624,29 @@ class CountTest(QueryTest): q = s.query(User, Address).join(User.addresses).limit(100) eq_(q.count(), 5) + def test_nested_future(self): + User, Address = self.classes.User, self.classes.Address + + s = create_session() + + stmt = future_select(User, Address).join(Address, true()).limit(2) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 2, + ) + + stmt = future_select(User, Address).join(Address, true()).limit(100) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 20, + ) + + stmt = future_select(User, Address).join(Address).limit(100) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 5, + ) + def test_cols(self): """test that column-based queries always nest.""" @@ -3615,6 +3670,49 @@ class CountTest(QueryTest): eq_(q.count(), 5) eq_(q.distinct().count(), 3) + def test_cols_future(self): + + User, Address = self.classes.User, self.classes.Address + + s = create_session() + + stmt = future_select(func.count(distinct(User.name))) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 1, + ) + + stmt = future_select(func.count(distinct(User.name))).distinct() + + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 1, + ) + + stmt = future_select(User.name) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 4, + ) + + stmt = future_select(User.name, Address).join(Address, true()) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 20, + ) + + stmt = future_select(Address.user_id) + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 5, + ) + + stmt = stmt.distinct() + eq_( + s.scalar(future_select(func.count()).select_from(stmt.subquery())), + 3, + ) + class DistinctTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default"