]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Inline a few ORM arguments, others
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Jun 2020 13:50:04 +0000 (09:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Jun 2020 15:36:37 +0000 (11:36 -0400)
small changes

Change-Id: Id89a0651196c431d0aaf6935f5a4e7b12dd70c6c

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
test/aaa_profiling/test_orm.py
test/orm/test_query.py

index ba30d203bde7925bd22d455bd691cecf5bbff9fa..bd4074ea1115609a662a6c7f7cff657587d6a4f7 100644 (file)
@@ -191,15 +191,9 @@ class ORMCompileState(CompileState):
     def orm_pre_session_exec(
         cls, session, statement, execution_options, bind_arguments
     ):
-        if execution_options:
-            # TODO: will have to provide public API to set some load
-            # options and also extract them from that API here, likely
-            # execution options
-            load_options = execution_options.get(
-                "_sa_orm_load_options", QueryContext.default_load_options
-            )
-        else:
-            load_options = QueryContext.default_load_options
+        load_options = execution_options.get(
+            "_sa_orm_load_options", QueryContext.default_load_options
+        )
 
         bind_arguments["clause"] = statement
 
@@ -223,7 +217,9 @@ class ORMCompileState(CompileState):
             session._autoflush()
 
     @classmethod
-    def orm_setup_cursor_result(cls, session, bind_arguments, result):
+    def orm_setup_cursor_result(
+        cls, session, statement, execution_options, bind_arguments, result
+    ):
         execution_context = result.context
         compile_state = execution_context.compiled.compile_state
 
@@ -231,13 +227,9 @@ class ORMCompileState(CompileState):
         # were passed to session.execute:
         # session.execute(legacy_select([User.id, User.name]))
         # see test_query->test_legacy_tuple_old_select
-        if not execution_context.compiled.statement._is_future:
+        if not statement._is_future:
             return result
 
-        execution_options = execution_context.execution_options
-
-        # we are getting these right above in orm_pre_session_exec(),
-        # then getting them again right here.
         load_options = execution_options.get(
             "_sa_orm_load_options", QueryContext.default_load_options
         )
index 25e2243487f60bf012a751dee9bc1dbc6c7e992b..ee42419a261d0759214527dc9ed27899af2cfa02 100644 (file)
@@ -1297,17 +1297,16 @@ class Session(_SessionClassMethods):
         )
 
     def _connection_for_bind(self, engine, execution_options=None, **kw):
-        self._autobegin()
-
-        if self._transaction is not None:
+        if self._transaction is not None or self._autobegin():
             return self._transaction._connection_for_bind(
                 engine, execution_options
             )
-        else:
-            conn = engine.connect(**kw)
-            if execution_options:
-                conn = conn.execution_options(**execution_options)
-            return conn
+
+        assert self._transaction is None
+        conn = engine.connect(**kw)
+        if execution_options:
+            conn = conn.execution_options(**execution_options)
+        return conn
 
     def execute(
         self,
@@ -1460,6 +1459,23 @@ class Session(_SessionClassMethods):
             compile_state_cls.orm_pre_session_exec(
                 self, statement, execution_options, bind_arguments
             )
+
+            if self.dispatch.do_orm_execute:
+                skip_events = bind_arguments.pop("_sa_skip_events", False)
+
+                if not skip_events:
+                    orm_exec_state = ORMExecuteState(
+                        self,
+                        statement,
+                        params,
+                        execution_options,
+                        bind_arguments,
+                    )
+                    for fn in self.dispatch.do_orm_execute:
+                        result = fn(orm_exec_state)
+                        if result:
+                            return result
+
         else:
             compile_state_cls = None
             bind_arguments.setdefault("clause", statement)
@@ -1468,18 +1484,6 @@ class Session(_SessionClassMethods):
                     execution_options, {"future_result": True}
                 )
 
-        if self.dispatch.do_orm_execute:
-            skip_events = bind_arguments.pop("_sa_skip_events", False)
-
-            if not skip_events:
-                orm_exec_state = ORMExecuteState(
-                    self, statement, params, execution_options, bind_arguments
-                )
-                for fn in self.dispatch.do_orm_execute:
-                    result = fn(orm_exec_state)
-                    if result:
-                        return result
-
         bind = self.get_bind(**bind_arguments)
 
         conn = self._connection_for_bind(bind, close_with_result=True)
@@ -1487,7 +1491,7 @@ class Session(_SessionClassMethods):
 
         if compile_state_cls:
             result = compile_state_cls.orm_setup_cursor_result(
-                self, bind_arguments, result
+                self, statement, execution_options, bind_arguments, result
             )
 
         return result
index 4bd19e04b53579fdb907fda137d56273dc89f226..f4160b5520cb06c031d483f1f702bbf5dbd0941d 100644 (file)
@@ -2829,10 +2829,12 @@ class SQLCompiler(Compiled):
 
         if self.linting & COLLECT_CARTESIAN_PRODUCTS:
             from_linter = FromLinter({}, set())
+            warn_linting = self.linting & WARN_LINTING
             if toplevel:
                 self.from_linter = from_linter
         else:
             from_linter = None
+            warn_linting = False
 
         if froms:
             text += " \nFROM "
@@ -2872,10 +2874,7 @@ class SQLCompiler(Compiled):
             if t:
                 text += " \nWHERE " + t
 
-        if (
-            self.linting & COLLECT_CARTESIAN_PRODUCTS
-            and self.linting & WARN_LINTING
-        ):
+        if warn_linting:
             from_linter.warn()
 
         if select._group_by_clauses:
index fa2888a23eae54d3f2e4b9fa374bef0d6abf91ca..986bf134c7787ed97f0ae214abe35fe034b86c18 100644 (file)
@@ -2189,6 +2189,10 @@ class BooleanClauseList(ClauseList, ColumnElement):
         has_continue_on = None
 
         convert_clauses = []
+
+        against = operators._asbool
+        lcc = 0
+
         for clause in clauses:
             if clause is continue_on:
                 # instance of continue_on, like and_(x, y, True, z), store it
@@ -2199,24 +2203,26 @@ class BooleanClauseList(ClauseList, ColumnElement):
                 # instance of skip_on, e.g. and_(x, y, False, z), cancels
                 # the rest out
                 convert_clauses = [clause]
+                lcc = 1
                 break
             else:
+                if not lcc:
+                    lcc = 1
+                else:
+                    against = operator
+                    # techincally this would be len(convert_clauses) + 1
+                    # however this only needs to indicate "greater than one"
+                    lcc = 2
                 convert_clauses.append(clause)
 
         if not convert_clauses and has_continue_on is not None:
             convert_clauses = [has_continue_on]
+            lcc = 1
 
-        lcc = len(convert_clauses)
-
-        if lcc > 1:
-            against = operator
-        else:
-            against = operators._asbool
         return lcc, [c.self_group(against=against) for c in convert_clauses]
 
     @classmethod
     def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
-
         lcc, convert_clauses = cls._process_clauses_for_boolean(
             operator,
             continue_on,
index 5dbbc2f5c176a21730e9c44308451c86070e3d56..1b31c96e9c8917f40c58c4386465973e03c8231c 100644 (file)
@@ -870,7 +870,9 @@ class JoinedEagerLoadTest(fixtures.MappedTest):
                 )
 
                 r.context.compiled.compile_state = compile_state
-                obj = ORMCompileState.orm_setup_cursor_result(sess, {}, r)
+                obj = ORMCompileState.orm_setup_cursor_result(
+                    sess, compile_state.statement, exec_opts, {}, r
+                )
                 list(obj)
                 sess.close()
 
index 5cc6140a512f12f3eea1a2cc9c6ab4991414f484..a7cc82ddfa1d6748b41f062d9967ce5e76d4608f 100644 (file)
@@ -3551,6 +3551,23 @@ class CountTest(QueryTest):
 
         eq_(s.query(User).filter(users.c.name.endswith("ed")).count(), 2)
 
+    def test_basic_future(self):
+        User = self.classes.User
+
+        s = create_session()
+
+        eq_(
+            s.execute(future_select(func.count()).select_from(User)).scalar(),
+            4,
+        )
+
+        eq_(
+            s.execute(
+                future_select(func.count()).filter(User.name.endswith("ed"))
+            ).scalar(),
+            2,
+        )
+
     def test_count_char(self):
         User = self.classes.User
         s = create_session()
@@ -3579,6 +3596,21 @@ class CountTest(QueryTest):
         q = s.query(User, Address).join(User.addresses)
         eq_(q.count(), 5)
 
+    def test_multiple_entity_future(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        s = create_session()
+
+        stmt = future_select(User, Address).join(Address, true())
+
+        stmt = future_select(func.count()).select_from(stmt.subquery())
+        eq_(s.scalar(stmt), 20)  # cartesian product
+
+        stmt = future_select(User, Address).join(Address)
+
+        stmt = future_select(func.count()).select_from(stmt.subquery())
+        eq_(s.scalar(stmt), 5)
+
     def test_nested(self):
         User, Address = self.classes.User, self.classes.Address
 
@@ -3592,6 +3624,29 @@ class CountTest(QueryTest):
         q = s.query(User, Address).join(User.addresses).limit(100)
         eq_(q.count(), 5)
 
+    def test_nested_future(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        s = create_session()
+
+        stmt = future_select(User, Address).join(Address, true()).limit(2)
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            2,
+        )
+
+        stmt = future_select(User, Address).join(Address, true()).limit(100)
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            20,
+        )
+
+        stmt = future_select(User, Address).join(Address).limit(100)
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            5,
+        )
+
     def test_cols(self):
         """test that column-based queries always nest."""
 
@@ -3615,6 +3670,49 @@ class CountTest(QueryTest):
         eq_(q.count(), 5)
         eq_(q.distinct().count(), 3)
 
+    def test_cols_future(self):
+
+        User, Address = self.classes.User, self.classes.Address
+
+        s = create_session()
+
+        stmt = future_select(func.count(distinct(User.name)))
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            1,
+        )
+
+        stmt = future_select(func.count(distinct(User.name))).distinct()
+
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            1,
+        )
+
+        stmt = future_select(User.name)
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            4,
+        )
+
+        stmt = future_select(User.name, Address).join(Address, true())
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            20,
+        )
+
+        stmt = future_select(Address.user_id)
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            5,
+        )
+
+        stmt = stmt.distinct()
+        eq_(
+            s.scalar(future_select(func.count()).select_from(stmt.subquery())),
+            3,
+        )
+
 
 class DistinctTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"