def orm_pre_session_exec(
cls, session, statement, execution_options, bind_arguments
):
- if execution_options:
- # TODO: will have to provide public API to set some load
- # options and also extract them from that API here, likely
- # execution options
- load_options = execution_options.get(
- "_sa_orm_load_options", QueryContext.default_load_options
- )
- else:
- load_options = QueryContext.default_load_options
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
bind_arguments["clause"] = statement
session._autoflush()
@classmethod
- def orm_setup_cursor_result(cls, session, bind_arguments, result):
+ def orm_setup_cursor_result(
+ cls, session, statement, execution_options, bind_arguments, result
+ ):
execution_context = result.context
compile_state = execution_context.compiled.compile_state
# were passed to session.execute:
# session.execute(legacy_select([User.id, User.name]))
# see test_query->test_legacy_tuple_old_select
- if not execution_context.compiled.statement._is_future:
+ if not statement._is_future:
return result
- execution_options = execution_context.execution_options
-
- # we are getting these right above in orm_pre_session_exec(),
- # then getting them again right here.
load_options = execution_options.get(
"_sa_orm_load_options", QueryContext.default_load_options
)
)
def _connection_for_bind(self, engine, execution_options=None, **kw):
- self._autobegin()
-
- if self._transaction is not None:
+ if self._transaction is not None or self._autobegin():
return self._transaction._connection_for_bind(
engine, execution_options
)
- else:
- conn = engine.connect(**kw)
- if execution_options:
- conn = conn.execution_options(**execution_options)
- return conn
+
+ assert self._transaction is None
+ conn = engine.connect(**kw)
+ if execution_options:
+ conn = conn.execution_options(**execution_options)
+ return conn
def execute(
self,
compile_state_cls.orm_pre_session_exec(
self, statement, execution_options, bind_arguments
)
+
+ if self.dispatch.do_orm_execute:
+ skip_events = bind_arguments.pop("_sa_skip_events", False)
+
+ if not skip_events:
+ orm_exec_state = ORMExecuteState(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ )
+ for fn in self.dispatch.do_orm_execute:
+ result = fn(orm_exec_state)
+ if result:
+ return result
+
else:
compile_state_cls = None
bind_arguments.setdefault("clause", statement)
execution_options, {"future_result": True}
)
- if self.dispatch.do_orm_execute:
- skip_events = bind_arguments.pop("_sa_skip_events", False)
-
- if not skip_events:
- orm_exec_state = ORMExecuteState(
- self, statement, params, execution_options, bind_arguments
- )
- for fn in self.dispatch.do_orm_execute:
- result = fn(orm_exec_state)
- if result:
- return result
-
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind, close_with_result=True)
if compile_state_cls:
result = compile_state_cls.orm_setup_cursor_result(
- self, bind_arguments, result
+ self, statement, execution_options, bind_arguments, result
)
return result
if self.linting & COLLECT_CARTESIAN_PRODUCTS:
from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
if toplevel:
self.from_linter = from_linter
else:
from_linter = None
+ warn_linting = False
if froms:
text += " \nFROM "
if t:
text += " \nWHERE " + t
- if (
- self.linting & COLLECT_CARTESIAN_PRODUCTS
- and self.linting & WARN_LINTING
- ):
+ if warn_linting:
from_linter.warn()
if select._group_by_clauses:
has_continue_on = None
convert_clauses = []
+
+ against = operators._asbool
+ lcc = 0
+
for clause in clauses:
if clause is continue_on:
# instance of continue_on, like and_(x, y, True, z), store it
# instance of skip_on, e.g. and_(x, y, False, z), cancels
# the rest out
convert_clauses = [clause]
+ lcc = 1
break
else:
+ if not lcc:
+ lcc = 1
+ else:
+ against = operator
+ # techincally this would be len(convert_clauses) + 1
+ # however this only needs to indicate "greater than one"
+ lcc = 2
convert_clauses.append(clause)
if not convert_clauses and has_continue_on is not None:
convert_clauses = [has_continue_on]
+ lcc = 1
- lcc = len(convert_clauses)
-
- if lcc > 1:
- against = operator
- else:
- against = operators._asbool
return lcc, [c.self_group(against=against) for c in convert_clauses]
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
-
lcc, convert_clauses = cls._process_clauses_for_boolean(
operator,
continue_on,
)
r.context.compiled.compile_state = compile_state
- obj = ORMCompileState.orm_setup_cursor_result(sess, {}, r)
+ obj = ORMCompileState.orm_setup_cursor_result(
+ sess, compile_state.statement, exec_opts, {}, r
+ )
list(obj)
sess.close()
eq_(s.query(User).filter(users.c.name.endswith("ed")).count(), 2)
+ def test_basic_future(self):
+ User = self.classes.User
+
+ s = create_session()
+
+ eq_(
+ s.execute(future_select(func.count()).select_from(User)).scalar(),
+ 4,
+ )
+
+ eq_(
+ s.execute(
+ future_select(func.count()).filter(User.name.endswith("ed"))
+ ).scalar(),
+ 2,
+ )
+
def test_count_char(self):
User = self.classes.User
s = create_session()
q = s.query(User, Address).join(User.addresses)
eq_(q.count(), 5)
+ def test_multiple_entity_future(self):
+ User, Address = self.classes.User, self.classes.Address
+
+ s = create_session()
+
+ stmt = future_select(User, Address).join(Address, true())
+
+ stmt = future_select(func.count()).select_from(stmt.subquery())
+ eq_(s.scalar(stmt), 20) # cartesian product
+
+ stmt = future_select(User, Address).join(Address)
+
+ stmt = future_select(func.count()).select_from(stmt.subquery())
+ eq_(s.scalar(stmt), 5)
+
def test_nested(self):
User, Address = self.classes.User, self.classes.Address
q = s.query(User, Address).join(User.addresses).limit(100)
eq_(q.count(), 5)
+ def test_nested_future(self):
+ User, Address = self.classes.User, self.classes.Address
+
+ s = create_session()
+
+ stmt = future_select(User, Address).join(Address, true()).limit(2)
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 2,
+ )
+
+ stmt = future_select(User, Address).join(Address, true()).limit(100)
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 20,
+ )
+
+ stmt = future_select(User, Address).join(Address).limit(100)
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 5,
+ )
+
def test_cols(self):
"""test that column-based queries always nest."""
eq_(q.count(), 5)
eq_(q.distinct().count(), 3)
+ def test_cols_future(self):
+
+ User, Address = self.classes.User, self.classes.Address
+
+ s = create_session()
+
+ stmt = future_select(func.count(distinct(User.name)))
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 1,
+ )
+
+ stmt = future_select(func.count(distinct(User.name))).distinct()
+
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 1,
+ )
+
+ stmt = future_select(User.name)
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 4,
+ )
+
+ stmt = future_select(User.name, Address).join(Address, true())
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 20,
+ )
+
+ stmt = future_select(Address.user_id)
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 5,
+ )
+
+ stmt = stmt.distinct()
+ eq_(
+ s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+ 3,
+ )
+
class DistinctTest(QueryTest, AssertsCompiledSQL):
__dialect__ = "default"