]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
cx_Oracle modernize
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Apr 2022 14:13:23 +0000 (10:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 7 Apr 2022 14:47:53 +0000 (10:47 -0400)
Full "RETURNING" support is implemented for the cx_Oracle dialect, meaning
multiple RETURNING rows are now recived for DML statements that produce
more than one row for RETURNING.

cx_Oracle 7 is now the minimum version for cx_Oracle.

Getting Oracle to do multirow returning took about 5 minutes.  however,
getting Oracle's RETURNING system to integrate with ORM-enabled
insert, update, delete, is a big deal because that architecture wasn't
really working very robustly, including some recent changes in 1.4
for FromStatement were done in a hurry, so this patch also cleans up
the FromStatement situation and begins to establish it more concretely
as the base for all ReturnsRows / TextClause ORM scenarios.

Fixes: #6245
Change-Id: I2b4e6007affa51ce311d2d5baa3917f356ab961f

24 files changed:
doc/build/changelog/unreleased_20/6245.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/_typing.py [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
setup.cfg
test/dialect/oracle/test_dialect.py
test/engine/test_deprecations.py
test/ext/declarative/test_reflection.py
test/orm/test_session.py
test/orm/test_transaction.py
test/sql/test_insert.py

diff --git a/doc/build/changelog/unreleased_20/6245.rst b/doc/build/changelog/unreleased_20/6245.rst
new file mode 100644 (file)
index 0000000..1247544
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: oracle, feature
+    :tickets: 6245
+
+    Full "RETURNING" support is implemented for the cx_Oracle dialect, meaning
+    multiple RETURNING rows are now recived for DML statements that produce
+    more than one row for RETURNING.
+
+
+.. change::
+    :tags: oracle
+
+    cx_Oracle 7 is now the minimum version for cx_Oracle.
index 07ff495a717b8f3fe0678b4f7c2dfbbcb79b1089..ac02b98a04d6fdf94fb89c5900dd9fe71fb14505 100644 (file)
@@ -805,10 +805,13 @@ https://msdn.microsoft.com/en-us/library/ms175095.aspx.
 
 """  # noqa
 
+from __future__ import annotations
+
 import codecs
 import datetime
 import operator
 import re
+from typing import TYPE_CHECKING
 
 from . import information_schema as ischema
 from .json import JSON
@@ -833,6 +836,7 @@ from ...sql import func
 from ...sql import quoted_name
 from ...sql import roles
 from ...sql import util as sql_util
+from ...sql._typing import is_sql_compiler
 from ...types import BIGINT
 from ...types import BINARY
 from ...types import CHAR
@@ -849,6 +853,10 @@ from ...types import TEXT
 from ...types import VARCHAR
 from ...util import update_wrapper
 
+if TYPE_CHECKING:
+    from ...sql.compiler import SQLCompiler
+    from ...sql.dml import DMLState
+    from ...sql.selectable import TableClause
 
 # https://sqlserverbuilds.blogspot.com/
 MS_2017_VERSION = (14,)
@@ -1623,6 +1631,8 @@ class MSExecutionContext(default.DefaultExecutionContext):
     _lastrowid = None
     _rowcount = None
 
+    dialect: MSDialect
+
     def _opt_encode(self, statement):
 
         if self.compiled and self.compiled.schema_translate_map:
@@ -1636,13 +1646,20 @@ class MSExecutionContext(default.DefaultExecutionContext):
         """Activate IDENTITY_INSERT if needed."""
 
         if self.isinsert:
+            if TYPE_CHECKING:
+                assert is_sql_compiler(self.compiled)
+                assert isinstance(self.compiled.compile_state, DMLState)
+                assert isinstance(
+                    self.compiled.compile_state.dml_table, TableClause
+                )
+
             tbl = self.compiled.compile_state.dml_table
             id_column = tbl._autoincrement_column
-            insert_has_identity = (id_column is not None) and (
-                not isinstance(id_column.default, Sequence)
-            )
 
-            if insert_has_identity:
+            if id_column is not None and (
+                not isinstance(id_column.default, Sequence)
+            ):
+                insert_has_identity = True
                 compile_state = self.compiled.compile_state
                 self._enable_identity_insert = (
                     id_column.key in self.compiled_parameters[0]
@@ -1655,12 +1672,13 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 )
 
             else:
+                insert_has_identity = False
                 self._enable_identity_insert = False
 
             self._select_lastrowid = (
                 not self.compiled.inline
                 and insert_has_identity
-                and not self.compiled.returning
+                and not self.compiled.effective_returning
                 and not self._enable_identity_insert
                 and not self.executemany
             )
@@ -1701,8 +1719,10 @@ class MSExecutionContext(default.DefaultExecutionContext):
             self._lastrowid = int(row[0])
 
         elif (
-            self.isinsert or self.isupdate or self.isdelete
-        ) and self.compiled.returning:
+            self.compiled is not None
+            and is_sql_compiler(self.compiled)
+            and self.compiled.effective_returning
+        ):
             self.cursor_fetch_strategy = (
                 _cursor.FullyBufferedCursorFetchStrategy(
                     self.cursor,
@@ -1712,6 +1732,12 @@ class MSExecutionContext(default.DefaultExecutionContext):
             )
 
         if self._enable_identity_insert:
+            if TYPE_CHECKING:
+                assert is_sql_compiler(self.compiled)
+                assert isinstance(self.compiled.compile_state, DMLState)
+                assert isinstance(
+                    self.compiled.compile_state.dml_table, TableClause
+                )
             conn._cursor_execute(
                 self.cursor,
                 self._opt_encode(
@@ -2065,17 +2091,21 @@ class MSSQLCompiler(compiler.SQLCompiler):
             )
         return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
 
-    def returning_clause(self, stmt, returning_cols):
+    def returning_clause(
+        self, stmt, returning_cols, *, populate_result_map, **kw
+    ):
         # SQL server returning clause requires that the columns refer to
         # the virtual table names "inserted" or "deleted".   Here, we make
         # a simple alias of our table with that name, and then adapt the
         # columns we have from the list of RETURNING columns to that new name
         # so that they render as "inserted.<colname>" / "deleted.<colname>".
 
-        if self.isinsert or self.isupdate:
+        if stmt.is_insert or stmt.is_update:
             target = stmt.table.alias("inserted")
-        else:
+        elif stmt.is_delete:
             target = stmt.table.alias("deleted")
+        else:
+            assert False, "expected Insert, Update or Delete statement"
 
         adapter = sql_util.ClauseAdapter(target)
 
@@ -2091,6 +2121,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             self._label_returning_column(
                 stmt,
                 adapter.traverse(c),
+                populate_result_map,
                 {"result_map_targets": (c,)},
             )
             for c in expression._select_iterables(returning_cols)
index 3ee38c0cf1836c87f9090de58283c586b1e28db0..39a542cce87d88b020c092bc064a6e4b6ba15d9e 100644 (file)
@@ -1089,7 +1089,9 @@ class OracleCompiler(compiler.SQLCompiler):
 
         return " " + alias_name_text
 
-    def returning_clause(self, stmt, returning_cols):
+    def returning_clause(
+        self, stmt, returning_cols, *, populate_result_map, **kw
+    ):
         columns = []
         binds = []
 
@@ -1122,23 +1124,34 @@ class OracleCompiler(compiler.SQLCompiler):
                 self.bindparam_string(self._truncate_bindparam(outparam))
             )
 
-            # ensure the ExecutionContext.get_out_parameters() method is
-            # *not* called; the cx_Oracle dialect wants to handle these
-            # parameters separately
-            self.has_out_parameters = False
+            # has_out_parameters would in a normal case be set to True
+            # as a result of the compiler visiting an outparam() object.
+            # in this case, the above outparam() objects are not being
+            # visited.   Ensure the statement itself didn't have other
+            # outparam() objects independently.
+            # technically, this could be supported, but as it would be
+            # a very strange use case without a clear rationale, disallow it
+            if self.has_out_parameters:
+                raise exc.InvalidRequestError(
+                    "Using explicit outparam() objects with "
+                    "UpdateBase.returning() in the same Core DML statement "
+                    "is not supported in the Oracle dialect."
+                )
 
-            columns.append(self.process(col_expr, within_columns_clause=False))
+            self._oracle_returning = True
 
-            self._add_to_result_map(
-                getattr(col_expr, "name", col_expr._anon_name_label),
-                getattr(col_expr, "name", col_expr._anon_name_label),
-                (
-                    column,
-                    getattr(column, "name", None),
-                    getattr(column, "key", None),
-                ),
-                column.type,
-            )
+            columns.append(self.process(col_expr, within_columns_clause=False))
+            if populate_result_map:
+                self._add_to_result_map(
+                    getattr(col_expr, "name", col_expr._anon_name_label),
+                    getattr(col_expr, "name", col_expr._anon_name_label),
+                    (
+                        column,
+                        getattr(column, "name", None),
+                        getattr(column, "key", None),
+                    ),
+                    column.type,
+                )
 
         return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds)
 
@@ -1510,6 +1523,7 @@ class OracleDialect(default.DefaultDialect):
     max_identifier_length = 128
 
     implicit_returning = True
+    full_returning = True
 
     div_is_floordiv = False
 
index 5208f96718b4b39baa1e74141e8ccf01e739ed88..9f3394533167fde6f67070c17cce078351ff1209 100644 (file)
@@ -352,8 +352,7 @@ RETURNING Support
 -----------------
 
 The cx_Oracle dialect implements RETURNING using OUT parameters.
-The dialect supports RETURNING fully, however cx_Oracle 6 is recommended
-for complete support.
+The dialect supports RETURNING fully.
 
 .. _cx_oracle_lob:
 
@@ -430,6 +429,8 @@ SQLAlchemy type (or a subclass of such).
    as better integration of outputtypehandlers.
 
 """  # noqa
+from __future__ import annotations
+
 import decimal
 import random
 import re
@@ -444,6 +445,7 @@ from ... import util
 from ...engine import cursor as _cursor
 from ...engine import interfaces
 from ...engine import processors
+from ...sql._typing import is_sql_compiler
 
 
 class _OracleInteger(sqltypes.Integer):
@@ -452,10 +454,13 @@ class _OracleInteger(sqltypes.Integer):
         # 208#issuecomment-409715955
         return int
 
-    def _cx_oracle_var(self, dialect, cursor):
+    def _cx_oracle_var(self, dialect, cursor, arraysize=None):
         cx_Oracle = dialect.dbapi
         return cursor.var(
-            cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int
+            cx_Oracle.STRING,
+            255,
+            arraysize=arraysize if arraysize is not None else cursor.arraysize,
+            outconverter=int,
         )
 
     def _cx_oracle_outputtypehandler(self, dialect):
@@ -494,8 +499,6 @@ class _OracleNumeric(sqltypes.Numeric):
     def _cx_oracle_outputtypehandler(self, dialect):
         cx_Oracle = dialect.dbapi
 
-        is_cx_oracle_6 = dialect._is_cx_oracle_6
-
         def handler(cursor, name, default_type, size, precision, scale):
             outconverter = None
 
@@ -506,11 +509,8 @@ class _OracleNumeric(sqltypes.Numeric):
                         # allows for float("inf") to be handled
                         type_ = default_type
                         outconverter = decimal.Decimal
-                    elif is_cx_oracle_6:
-                        type_ = decimal.Decimal
                     else:
-                        type_ = cx_Oracle.STRING
-                        outconverter = dialect._to_decimal
+                        type_ = decimal.Decimal
                 else:
                     if self.is_number and scale == 0:
                         # integer. cx_Oracle is observed to handle the widest
@@ -525,11 +525,8 @@ class _OracleNumeric(sqltypes.Numeric):
                     if default_type == cx_Oracle.NATIVE_FLOAT:
                         type_ = default_type
                         outconverter = decimal.Decimal
-                    elif is_cx_oracle_6:
-                        type_ = decimal.Decimal
                     else:
-                        type_ = cx_Oracle.STRING
-                        outconverter = dialect._to_decimal
+                        type_ = decimal.Decimal
                 else:
                     if self.is_number and scale == 0:
                         # integer. cx_Oracle is observed to handle the widest
@@ -670,6 +667,8 @@ class _OracleRowid(oracle.ROWID):
 class OracleCompiler_cx_oracle(OracleCompiler):
     _oracle_cx_sql_compiler = True
 
+    _oracle_returning = False
+
     def bindparam_string(self, name, **kw):
         quote = getattr(name, "quote", None)
         if (
@@ -696,7 +695,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
     def _generate_out_parameter_vars(self):
         # check for has_out_parameters or RETURNING, create cx_Oracle.var
         # objects if so
-        if self.compiled.returning or self.compiled.has_out_parameters:
+        if self.compiled.has_out_parameters or self.compiled._oracle_returning:
             quoted_bind_names = self.compiled.escaped_bind_names
             for bindparam in self.compiled.binds.values():
                 if bindparam.isoutparam:
@@ -705,7 +704,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
 
                     if hasattr(type_impl, "_cx_oracle_var"):
                         self.out_parameters[name] = type_impl._cx_oracle_var(
-                            self.dialect, self.cursor
+                            self.dialect, self.cursor, arraysize=1
                         )
                     else:
                         dbtype = type_impl.get_dbapi_type(self.dialect.dbapi)
@@ -726,10 +725,14 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
                             cx_Oracle.NCLOB,
                         ):
                             self.out_parameters[name] = self.cursor.var(
-                                dbtype, outconverter=lambda value: value.read()
+                                dbtype,
+                                outconverter=lambda value: value.read(),
+                                arraysize=1,
                             )
                         else:
-                            self.out_parameters[name] = self.cursor.var(dbtype)
+                            self.out_parameters[name] = self.cursor.var(
+                                dbtype, arraysize=1
+                            )
                     self.parameters[0][
                         quoted_bind_names.get(name, name)
                     ] = self.out_parameters[name]
@@ -782,22 +785,33 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
         self._generate_cursor_outputtype_handler()
 
     def post_exec(self):
-        if self.compiled and self.out_parameters and self.compiled.returning:
+        if (
+            self.compiled
+            and is_sql_compiler(self.compiled)
+            and self.compiled._oracle_returning
+        ):
             # create a fake cursor result from the out parameters. unlike
             # get_out_parameter_values(), the result-row handlers here will be
             # applied at the Result level
-            returning_params = [
-                self.dialect._returningval(self.out_parameters["ret_%d" % i])
-                for i in range(len(self.out_parameters))
+
+            numrows = len(self.out_parameters["ret_0"].values[0])
+            numcols = len(self.out_parameters)
+
+            initial_buffer = [
+                tuple(
+                    self.out_parameters[f"ret_{j}"].values[0][i]
+                    for j in range(numcols)
+                )
+                for i in range(numrows)
             ]
 
             fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy(
                 self.cursor,
                 [
-                    (getattr(col, "name", col._anon_name_label), None)
-                    for col in self.compiled.returning
+                    (entry.keyname, None)
+                    for entry in self.compiled._result_columns
                 ],
-                initial_buffer=[tuple(returning_params)],
+                initial_buffer=initial_buffer,
             )
 
             self.cursor_fetch_strategy = fetch_strategy
@@ -908,18 +922,11 @@ class OracleDialect_cx_oracle(OracleDialect):
             self.cx_oracle_ver = (0, 0, 0)
         else:
             self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version)
-            if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0):
+            if self.cx_oracle_ver < (7,) and self.cx_oracle_ver > (0, 0, 0):
                 raise exc.InvalidRequestError(
-                    "cx_Oracle version 5.2 and above are supported"
+                    "cx_Oracle version 7 and above are supported"
                 )
 
-            if encoding_errors and self.cx_oracle_ver < (6, 4):
-                util.warn(
-                    "cx_oracle version %r does not support encodingErrors"
-                    % (self.cx_oracle_ver,)
-                )
-                self._cursor_var_unicode_kwargs = util.immutabledict()
-
             self.include_set_input_sizes = {
                 cx_Oracle.DATETIME,
                 cx_Oracle.NCLOB,
@@ -937,24 +944,6 @@ class OracleDialect_cx_oracle(OracleDialect):
 
             self._paramval = lambda value: value.getvalue()
 
-            # https://github.com/oracle/python-cx_Oracle/issues/176#issuecomment-386821291
-            # https://github.com/oracle/python-cx_Oracle/issues/224
-            self._values_are_lists = self.cx_oracle_ver >= (6, 3)
-            if self._values_are_lists:
-                cx_Oracle.__future__.dml_ret_array_val = True
-
-                def _returningval(value):
-                    try:
-                        return value.values[0][0]
-                    except IndexError:
-                        return None
-
-                self._returningval = _returningval
-            else:
-                self._returningval = self._paramval
-
-        self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,)
-
     def _parse_cx_oracle_ver(self, version):
         m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
         if m:
index 9bfde47688c627c89bf58f85152c196646d9b218..99945280491eed50535e8f989690a91abba8f471 100644 (file)
@@ -2401,10 +2401,11 @@ class PGCompiler(compiler.SQLCompiler):
 
         return tmp
 
-    def returning_clause(self, stmt, returning_cols):
-
+    def returning_clause(
+        self, stmt, returning_cols, *, populate_result_map, **kw
+    ):
         columns = [
-            self._label_returning_column(stmt, c)
+            self._label_returning_column(stmt, c, populate_result_map)
             for c in expression._select_iterables(returning_cols)
         ]
 
index 391368c5fc68547da967ece920d7fef957f35ab4..07783ced78a95e1a8922637279b70ee4e13fc8fa 100644 (file)
@@ -478,7 +478,7 @@ class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg):
         if (
             self._psycopg2_fetched_rows
             and self.compiled
-            and self.compiled.returning
+            and self.compiled.effective_returning
         ):
             # psycopg2 execute_values will provide for a real cursor where
             # cursor.description works correctly. however, it executes the
@@ -736,7 +736,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
                 statement,
                 parameters,
                 template=executemany_values,
-                fetch=bool(context.compiled.returning),
+                fetch=bool(context.compiled.effective_returning),
                 **kwargs,
             )
 
index 85ce91deb1720a8ba79634dcf8c90ef399388952..4d700866fe07f770c90998d176a16677a6b7eb29 100644 (file)
@@ -933,8 +933,11 @@ class DefaultExecutionContext(ExecutionContext):
                 assert isinstance(compiled.statement, UpdateBase)
             self.is_crud = True
             self._is_explicit_returning = bool(compiled.statement._returning)
-            self._is_implicit_returning = bool(
-                compiled.returning and not compiled.statement._returning
+            self._is_implicit_returning = is_implicit_returning = bool(
+                compiled.implicit_returning
+            )
+            assert not (
+                is_implicit_returning and compiled.statement._returning
             )
 
         if not parameters:
@@ -1165,12 +1168,6 @@ class DefaultExecutionContext(ExecutionContext):
         else:
             return ()
 
-    @util.memoized_property
-    def returning_cols(self) -> Optional[Sequence[ColumnsClauseRole]]:
-        if TYPE_CHECKING:
-            assert isinstance(self.compiled, SQLCompiler)
-        return self.compiled.returning
-
     @util.memoized_property
     def no_parameters(self):
         return self.execution_options.get("no_parameters", False)
@@ -1349,6 +1346,7 @@ class DefaultExecutionContext(ExecutionContext):
             result = _cursor.CursorResult(self, strategy, cursor_description)
 
         compiled = self.compiled
+
         if (
             compiled
             and not self.isddl
diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py
new file mode 100644 (file)
index 0000000..e9ddf6d
--- /dev/null
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import Union
+
+
+if TYPE_CHECKING:
+    from .mapper import Mapper
+    from .util import AliasedInsp
+
+_EntityType = Union[Mapper, AliasedInsp]
index bcd3e7d2334b3ee0d6ffc5e3d49bef30590ae8f7..edd3fb56bfab1b8c37af920d7b2a93463e5937c1 100644 (file)
@@ -8,7 +8,15 @@
 from __future__ import annotations
 
 import itertools
+from typing import Any
+from typing import Dict
 from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import attributes
 from . import interfaces
@@ -27,23 +35,39 @@ from .. import future
 from .. import inspect
 from .. import sql
 from .. import util
-from ..sql import ClauseElement
 from ..sql import coercions
 from ..sql import expression
 from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
+from ..sql._typing import is_dml
+from ..sql._typing import is_insert_update
+from ..sql._typing import is_select_base
 from ..sql.base import _select_iterables
 from ..sql.base import CacheableOptions
 from ..sql.base import CompileState
+from ..sql.base import Executable
 from ..sql.base import Options
+from ..sql.dml import UpdateBase
+from ..sql.elements import GroupedElement
+from ..sql.elements import TextClause
 from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
 from ..sql.selectable import LABEL_STYLE_NONE
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import ReturnsRows
 from ..sql.selectable import Select
+from ..sql.selectable import SelectLabelStyle
 from ..sql.selectable import SelectState
 from ..sql.visitors import InternalTraversal
 
+if TYPE_CHECKING:
+    from ._typing import _EntityType
+    from ..sql.compiler import _CompilerStackEntry
+    from ..sql.dml import _DMLTableElement
+    from ..sql.elements import ColumnElement
+    from ..sql.selectable import _LabelConventionCallable
+    from ..sql.selectable import SelectBase
+
 _path_registry = PathRegistry.root
 
 _EMPTY_DICT = util.immutabledict()
@@ -144,16 +168,6 @@ class QueryContext:
         self.yield_per = load_options._yield_per
         self.identity_token = load_options._refresh_identity_token
 
-        if self.yield_per and compile_state._no_yield_pers:
-            raise sa_exc.InvalidRequestError(
-                "The yield_per Query option is currently not "
-                "compatible with %s eager loading.  Please "
-                "specify lazyload('*') or query.enable_eagerloads(False) in "
-                "order to "
-                "proceed with query.yield_per()."
-                % ", ".join(compile_state._no_yield_pers)
-            )
-
 
 _orm_load_exec_options = util.immutabledict(
     {"_result_disable_adapt_to_context": True, "future_result": True}
@@ -196,7 +210,24 @@ class ORMCompileState(CompileState):
         _for_refresh_state = False
         _render_for_subquery = False
 
-    current_path = _path_registry
+    statement: Union[Select, FromStatement]
+    select_statement: Union[Select, FromStatement]
+    _entities: List[_QueryEntity]
+    _polymorphic_adapters: Dict[_EntityType, ORMAdapter]
+    compile_options: Union[
+        Type[default_compile_options], default_compile_options
+    ]
+    _primary_entity: Optional[_QueryEntity]
+    use_legacy_query_style: bool
+    _label_convention: _LabelConventionCallable
+    primary_columns: List[ColumnElement[Any]]
+    secondary_columns: List[ColumnElement[Any]]
+    dedupe_columns: Set[ColumnElement[Any]]
+    create_eager_joins: List[
+        # TODO: this structure is set up by JoinedLoader
+        Tuple[Any, ...]
+    ]
+    current_path: PathRegistry = _path_registry
 
     def __init__(self, *arg, **kw):
         raise NotImplementedError()
@@ -208,7 +239,9 @@ class ORMCompileState(CompileState):
             col_collection.append(obj)
 
     @classmethod
-    def _column_naming_convention(cls, label_style, legacy):
+    def _column_naming_convention(
+        cls, label_style: SelectLabelStyle, legacy: bool
+    ) -> _LabelConventionCallable:
 
         if legacy:
 
@@ -388,6 +421,10 @@ class ORMFromStatementCompileState(ORMCompileState):
     _from_obj_alias = None
     _has_mapper_entities = False
 
+    statement_container: FromStatement
+    requested_statement: Union[SelectBase, TextClause, UpdateBase]
+    dml_table: _DMLTableElement
+
     _has_orm_entities = False
     multi_row_eager_loaders = False
     compound_eager_adapter = None
@@ -403,6 +440,12 @@ class ORMFromStatementCompileState(ORMCompileState):
         else:
             toplevel = True
 
+        if not toplevel:
+            raise sa_exc.CompileError(
+                "The ORM FromStatement construct only supports being "
+                "invoked as the topmost statement, as it is only intended to "
+                "define how result rows should be returned."
+            )
         self = cls.__new__(cls)
         self._primary_entity = None
 
@@ -417,7 +460,6 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self._entities = []
         self._polymorphic_adapters = {}
-        self._no_yield_pers = set()
 
         self.compile_options = statement_container._compile_options
 
@@ -474,37 +516,47 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self.order_by = None
 
-        if isinstance(
-            self.statement, (expression.TextClause, expression.UpdateBase)
-        ):
-
+        if isinstance(self.statement, expression.TextClause):
+            # TextClause has no "column" objects at all.  for this case,
+            # we generate columns from our _QueryEntity objects, then
+            # flip on all the "please match no matter what" parameters.
             self.extra_criteria_entities = {}
 
-            # setup for all entities. Currently, this is not useful
-            # for eager loaders, as the eager loaders that work are able
-            # to do their work entirely in row_processor.
             for entity in self._entities:
                 entity.setup_compile_state(self)
 
-            # we did the setup just to get primary columns.
-            self.statement = _AdHocColumnsStatement(
-                self.statement, self.primary_columns
-            )
+            compiler._ordered_columns = (
+                compiler._textual_ordered_columns
+            ) = False
+
+            # enable looser result column matching.  this is shown to be
+            # needed by test_query.py::TextTest
+            compiler._loose_column_name_matching = True
+
+            for c in self.primary_columns:
+                compiler.process(
+                    c,
+                    within_columns_clause=True,
+                    add_to_result_map=compiler._add_to_result_map,
+                )
         else:
-            # allow TextualSelect with implicit columns as well
-            # as select() with ad-hoc columns, see test_query::TextTest
+            # for everyone else, Select, Insert, Update, TextualSelect, they
+            # have column objects already.  After much
+            # experimentation here, the best approach seems to be, use
+            # those columns completely, don't interfere with the compiler
+            # at all; just in ORM land, use an adapter to convert from
+            # our ORM columns to whatever columns are in the statement,
+            # before we look in the result row.  If the inner statement is
+            # not ORM enabled, assume looser col matching based on name
+            statement_is_orm = (
+                self.statement._propagate_attrs.get(
+                    "compile_state_plugin", None
+                )
+                == "orm"
+            )
             self._from_obj_alias = sql.util.ColumnAdapter(
-                self.statement, adapt_on_names=True
+                self.statement, adapt_on_names=not statement_is_orm
             )
-            # set up for eager loaders, however if we fix subqueryload
-            # it should not need to do this here.  the model of eager loaders
-            # that can work entirely in row_processor might be interesting
-            # here though subqueryloader has a lot of upfront work to do
-            # see test/orm/test_query.py -> test_related_eagerload_against_text
-            # for where this part makes a difference.  would rather have
-            # subqueryload figure out what it needs more intelligently.
-            #            for entity in self._entities:
-            #                entity.setup_compile_state(self)
 
         return self
 
@@ -515,63 +567,91 @@ class ORMFromStatementCompileState(ORMCompileState):
         return None
 
 
-class _AdHocColumnsStatement(ClauseElement):
-    """internal object created to somewhat act like a SELECT when we
-    are selecting columns from a DML RETURNING.
+class FromStatement(GroupedElement, ReturnsRows, Executable):
+    """Core construct that represents a load of ORM objects from various
+    :class:`.ReturnsRows` and other classes including:
 
+    :class:`.Select`, :class:`.TextClause`, :class:`.TextualSelect`,
+    :class:`.CompoundSelect`, :class`.Insert`, :class:`.Update`,
+    and in theory, :class:`.Delete`.
 
     """
 
-    __visit_name__ = None
+    __visit_name__ = "orm_from_statement"
 
-    def __init__(self, text, columns):
-        self.element = text
-        self.column_args = [
-            coercions.expect(roles.ColumnsClauseRole, c) for c in columns
-        ]
+    _compile_options = ORMFromStatementCompileState.default_compile_options
 
-    def _generate_cache_key(self):
-        raise NotImplementedError()
+    _compile_state_factory = ORMFromStatementCompileState.create_for_statement
 
-    def _gen_cache_key(self, anon_map, bindparams):
-        raise NotImplementedError()
+    _for_update_arg = None
 
-    def _compiler_dispatch(
-        self, compiler, compound_index=None, asfrom=False, **kw
-    ):
-        """provide a fixed _compiler_dispatch method."""
+    element: Union[SelectBase, TextClause, UpdateBase]
 
-        toplevel = not compiler.stack
-        entry = (
-            compiler._default_stack_entry if toplevel else compiler.stack[-1]
-        )
+    _traverse_internals = [
+        ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+        ("element", InternalTraversal.dp_clauseelement),
+    ] + Executable._executable_traverse_internals
 
-        populate_result_map = (
-            toplevel
-            # these two might not be needed
-            or (
-                compound_index == 0
-                and entry.get("need_result_map_for_compound", False)
+    _cache_key_traversal = _traverse_internals + [
+        ("_compile_options", InternalTraversal.dp_has_cache_key)
+    ]
+
+    def __init__(self, entities, element):
+        self._raw_columns = [
+            coercions.expect(
+                roles.ColumnsClauseRole,
+                ent,
+                apply_propagate_attrs=self,
+                post_inspect=True,
             )
-            or entry.get("need_result_map_for_nested", False)
+            for ent in util.to_list(entities)
+        ]
+        self.element = element
+        self.is_dml = element.is_dml
+        self._label_style = (
+            element._label_style if is_select_base(element) else None
         )
 
-        if populate_result_map:
-            compiler._ordered_columns = (
-                compiler._textual_ordered_columns
-            ) = False
+    def _compiler_dispatch(self, compiler, **kw):
 
-            # enable looser result column matching.  this is shown to be
-            # needed by test_query.py::TextTest
-            compiler._loose_column_name_matching = True
+        """provide a fixed _compiler_dispatch method.
 
-            for c in self.column_args:
-                compiler.process(
-                    c,
-                    within_columns_clause=True,
-                    add_to_result_map=compiler._add_to_result_map,
-                )
-        return compiler.process(self.element, **kw)
+        This is roughly similar to using the sqlalchemy.ext.compiler
+        ``@compiles`` extension.
+
+        """
+
+        compile_state = self._compile_state_factory(self, compiler, **kw)
+
+        toplevel = not compiler.stack
+
+        if toplevel:
+            compiler.compile_state = compile_state
+
+        return compiler.process(compile_state.statement, **kw)
+
+    def _ensure_disambiguated_names(self):
+        return self
+
+    def get_children(self, **kw):
+        for elem in itertools.chain.from_iterable(
+            element._from_objects for element in self._raw_columns
+        ):
+            yield elem
+        for elem in super(FromStatement, self).get_children(**kw):
+            yield elem
+
+    @property
+    def _all_selected_columns(self):
+        return self.element._all_selected_columns
+
+    @property
+    def _returning(self):
+        return self.element._returning if is_dml(self.element) else None
+
+    @property
+    def _inline(self):
+        return self.element._inline if is_insert_update(self.element) else None
 
 
 @sql.base.CompileState.plugin_for("orm", "select")
@@ -633,7 +713,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         self._entities = []
         self._primary_entity = None
         self._polymorphic_adapters = {}
-        self._no_yield_pers = set()
 
         self.compile_options = select_statement._compile_options
 
@@ -2059,6 +2138,9 @@ class _QueryEntity:
     _null_column_type = False
     use_id_for_hash = False
 
+    def setup_compile_state(self, compile_state: ORMCompileState) -> None:
+        raise NotImplementedError()
+
     @classmethod
     def to_compile_state(
         cls, compile_state, entities, entities_collection, is_current_entities
index b9ced44d53a7f64bc1b8fd2986982185b22b3d34..a754bd4f2a6435bf86da32aa68adab22d0106f13 100644 (file)
@@ -21,7 +21,6 @@ database to return iterable result sets.
 from __future__ import annotations
 
 import collections.abc as collections_abc
-import itertools
 import operator
 from typing import Any
 from typing import Generic
@@ -40,9 +39,9 @@ from .base import _assertions
 from .context import _column_descriptions
 from .context import _determine_last_joined_entity
 from .context import _legacy_filter_by_entity_zero
+from .context import FromStatement
 from .context import LABEL_STYLE_LEGACY_ORM
 from .context import ORMCompileState
-from .context import ORMFromStatementCompileState
 from .context import QueryContext
 from .interfaces import ORMColumnDescription
 from .interfaces import ORMColumnsClauseRole
@@ -71,14 +70,10 @@ from ..sql.expression import Exists
 from ..sql.selectable import _MemoizedSelectEntities
 from ..sql.selectable import _SelectFromElements
 from ..sql.selectable import ForUpdateArg
-from ..sql.selectable import GroupedElement
 from ..sql.selectable import HasHints
 from ..sql.selectable import HasPrefixes
 from ..sql.selectable import HasSuffixes
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
-from ..sql.selectable import SelectBase
-from ..sql.selectable import SelectStatementGrouping
-from ..sql.visitors import InternalTraversal
 
 if TYPE_CHECKING:
     from ..sql.selectable import _SetupJoinsElement
@@ -2765,91 +2760,6 @@ class Query(
         return context
 
 
-class FromStatement(GroupedElement, SelectBase, Executable):
-    """Core construct that represents a load of ORM objects from a finished
-    select or text construct.
-
-    """
-
-    __visit_name__ = "orm_from_statement"
-
-    _compile_options = ORMFromStatementCompileState.default_compile_options
-
-    _compile_state_factory = ORMFromStatementCompileState.create_for_statement
-
-    _for_update_arg = None
-
-    _traverse_internals = [
-        ("_raw_columns", InternalTraversal.dp_clauseelement_list),
-        ("element", InternalTraversal.dp_clauseelement),
-    ] + Executable._executable_traverse_internals
-
-    _cache_key_traversal = _traverse_internals + [
-        ("_compile_options", InternalTraversal.dp_has_cache_key)
-    ]
-
-    def __init__(self, entities, element):
-        self._raw_columns = [
-            coercions.expect(
-                roles.ColumnsClauseRole,
-                ent,
-                apply_propagate_attrs=self,
-                post_inspect=True,
-            )
-            for ent in util.to_list(entities)
-        ]
-        self.element = element
-
-    def get_label_style(self):
-        return self._label_style
-
-    def set_label_style(self, label_style):
-        return SelectStatementGrouping(
-            self.element.set_label_style(label_style)
-        )
-
-    @property
-    def _label_style(self):
-        return self.element._label_style
-
-    def _compiler_dispatch(self, compiler, **kw):
-
-        """provide a fixed _compiler_dispatch method.
-
-        This is roughly similar to using the sqlalchemy.ext.compiler
-        ``@compiles`` extension.
-
-        """
-
-        compile_state = self._compile_state_factory(self, compiler, **kw)
-
-        toplevel = not compiler.stack
-
-        if toplevel:
-            compiler.compile_state = compile_state
-
-        return compiler.process(compile_state.statement, **kw)
-
-    def _ensure_disambiguated_names(self):
-        return self
-
-    def get_children(self, **kw):
-        for elem in itertools.chain.from_iterable(
-            element._from_objects for element in self._raw_columns
-        ):
-            yield elem
-        for elem in super(FromStatement, self).get_children(**kw):
-            yield elem
-
-    @property
-    def _returning(self):
-        return self.element._returning if self.element.is_dml else None
-
-    @property
-    def _inline(self):
-        return self.element._inline if self.element.is_dml else None
-
-
 class AliasOption(interfaces.LoaderOption):
     inherit_cache = False
 
index 0a72a93c5ca3865ee942520fb81547b48d45b526..bc1e0672c4cfe779e3d0d8399bce0cd06a2976f3 100644 (file)
@@ -14,6 +14,11 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
+    from .compiler import Compiled
+    from .compiler import DDLCompiler
+    from .compiler import SQLCompiler
+    from .dml import UpdateBase
+    from .dml import ValuesBase
     from .elements import ClauseElement
     from .elements import ColumnClause
     from .elements import ColumnElement
@@ -38,6 +43,7 @@ if TYPE_CHECKING:
     from .type_api import TypeEngine
     from ..util.typing import TypeGuard
 
+
 _T = TypeVar("_T", bound=Any)
 
 
@@ -153,6 +159,12 @@ _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
 
 if TYPE_CHECKING:
 
+    def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]:
+        ...
+
+    def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]:
+        ...
+
     def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
         ...
 
@@ -183,7 +195,13 @@ if TYPE_CHECKING:
     def is_subquery(t: FromClause) -> TypeGuard[Subquery]:
         ...
 
+    def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]:
+        ...
+
 else:
+
+    is_sql_compiler = operator.attrgetter("is_sql")
+    is_ddl_compiler = operator.attrgetter("is_ddl")
     is_named_from_clause = operator.attrgetter("named_with_column")
     is_column_element = operator.attrgetter("_is_column_element")
     is_text_clause = operator.attrgetter("_is_text_clause")
@@ -194,6 +212,7 @@ else:
     is_select_statement = operator.attrgetter("_is_select_statement")
     is_table = operator.attrgetter("_is_table")
     is_subquery = operator.attrgetter("_is_subquery")
+    is_dml = operator.attrgetter("is_dml")
 
 
 def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]:
@@ -206,3 +225,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
 
 def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
     return hasattr(s, "__clause_element__")
+
+
+def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]:
+    return c.is_dml and (c.is_insert or c.is_update)  # type: ignore
index 6b25d8fcd59446a822285c1a1b2de497d4e5596c..f766a5ac50756ac4c09b83823b1aa29f1b69c725 100644 (file)
@@ -937,7 +937,7 @@ class Executable(roles.StatementRole, Generative):
     _with_context_options: Tuple[
         Tuple[Callable[[CompileState], None], Any], ...
     ] = ()
-    _compile_options: Optional[CacheableOptions]
+    _compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]]
 
     _executable_traverse_internals = [
         ("_with_options", InternalTraversal.dp_executable_options),
@@ -982,7 +982,7 @@ class Executable(roles.StatementRole, Generative):
         ) -> Result:
             ...
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _all_selected_columns(self):
         raise NotImplementedError()
 
index 6ecfbf986682ae9ed9af420513f18a2e0b3fdf4d..522a0bd4a0830b6e6b135897c69a5259bf071ba3 100644 (file)
@@ -61,6 +61,8 @@ from . import operators
 from . import schema
 from . import selectable
 from . import sqltypes
+from ._typing import is_column_element
+from ._typing import is_dml
 from .base import _from_objects
 from .base import Executable
 from .base import NO_ARG
@@ -90,6 +92,7 @@ if typing.TYPE_CHECKING:
     from .elements import _truncated_label
     from .elements import BindParameter
     from .elements import ColumnClause
+    from .elements import ColumnElement
     from .elements import Label
     from .functions import Function
     from .selectable import Alias
@@ -492,6 +495,9 @@ class Compiled:
     defaults.
     """
 
+    is_sql = False
+    is_ddl = False
+
     _cached_metadata: Optional[CursorResultMetaData] = None
 
     _result_columns: Optional[List[ResultColumnsEntry]] = None
@@ -701,6 +707,8 @@ class SQLCompiler(Compiled):
 
     extract_map = EXTRACT_MAP
 
+    is_sql = True
+
     _result_columns: List[ResultColumnsEntry]
 
     compound_keywords = COMPOUND_KEYWORDS
@@ -725,9 +733,14 @@ class SQLCompiler(Compiled):
     """list of columns for which onupdate default values should be evaluated
     before an UPDATE takes place"""
 
-    returning: Optional[Sequence[roles.ColumnsClauseRole]]
-    """list of columns that will be delivered to cursor.description or
-    dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
+    implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None
+    """list of "implicit" returning columns for a toplevel INSERT or UPDATE
+    statement, used to receive newly generated values of columns.
+
+    .. versionadded:: 2.0  ``implicit_returning`` replaces the previous
+       ``returning`` collection, which was not a generalized RETURNING
+       collection and instead was in fact specific to the "implicit returning"
+       feature.
 
     """
 
@@ -750,12 +763,6 @@ class SQLCompiler(Compiled):
     TypeEngine. CursorResult uses this for type processing and
     column targeting"""
 
-    returning = None
-    """holds the "returning" collection of columns if
-    the statement is CRUD and defines returning columns
-    either implicitly or explicitly
-    """
-
     returning_precedes_values: bool = False
     """set to True classwide to generate RETURNING
     clauses before the VALUES or WHERE clause (i.e. MSSQL)
@@ -978,9 +985,6 @@ class SQLCompiler(Compiled):
             if TYPE_CHECKING:
                 assert isinstance(statement, UpdateBase)
 
-            if statement._returning:
-                self.returning = statement._returning
-
             if self.isinsert or self.isupdate:
                 if TYPE_CHECKING:
                     assert isinstance(statement, ValuesBase)
@@ -1001,6 +1005,39 @@ class SQLCompiler(Compiled):
         if self._render_postcompile:
             self._process_parameters_for_postcompile(_populate_self=True)
 
+    @util.ro_memoized_property
+    def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
+        """The effective "returning" columns for INSERT, UPDATE or DELETE.
+
+        This is either the so-called "implicit returning" columns which are
+        calculated by the compiler on the fly, or those present based on what's
+        present in ``self.statement._returning`` (expanded into individual
+        columns using the ``._all_selected_columns`` attribute) i.e. those set
+        explicitly using the :meth:`.UpdateBase.returning` method.
+
+        .. versionadded:: 2.0
+
+        """
+        if self.implicit_returning:
+            return self.implicit_returning
+        elif is_dml(self.statement):
+            return [
+                c
+                for c in self.statement._all_selected_columns
+                if is_column_element(c)
+            ]
+
+        else:
+            return None
+
+    @property
+    def returning(self):
+        """backwards compatibility; returns the
+        effective_returning collection.
+
+        """
+        return self.effective_returning
+
     @property
     def current_executable(self):
         """Return the current 'executable' that is being compiled.
@@ -1569,7 +1606,7 @@ class SQLCompiler(Compiled):
         param_key_getter = self._within_exec_param_key_getter
         table = self.statement.table
 
-        returning = self.returning
+        returning = self.implicit_returning
         assert returning is not None
         ret = {col: idx for idx, col in enumerate(returning)}
 
@@ -3373,7 +3410,9 @@ class SQLCompiler(Compiled):
             ResultColumnsEntry(keyname, name, objects, type_)
         )
 
-    def _label_returning_column(self, stmt, column, column_clause_args=None):
+    def _label_returning_column(
+        self, stmt, column, populate_result_map, column_clause_args=None
+    ):
         """Render a column with necessary labels inside of a RETURNING clause.
 
         This method is provided for individual dialects in place of calling
@@ -3386,7 +3425,7 @@ class SQLCompiler(Compiled):
         return self._label_select_column(
             None,
             column,
-            True,
+            populate_result_map,
             False,
             {} if column_clause_args is None else column_clause_args,
         )
@@ -4103,7 +4142,10 @@ class SQLCompiler(Compiled):
     def returning_clause(
         self,
         stmt: UpdateBase,
-        returning_cols: Sequence[roles.ColumnsClauseRole],
+        returning_cols: Sequence[ColumnElement[Any]],
+        *,
+        populate_result_map: bool,
+        **kw: Any,
     ) -> str:
         raise exc.CompileError(
             "RETURNING is not supported by this "
@@ -4228,7 +4270,6 @@ class SQLCompiler(Compiled):
         return dialect_hints, table_text
 
     def visit_insert(self, insert_stmt, **kw):
-
         compile_state = insert_stmt._compile_state_factory(
             insert_stmt, self, **kw
         )
@@ -4250,7 +4291,7 @@ class SQLCompiler(Compiled):
         )
 
         crud_params_struct = crud._get_crud_params(
-            self, insert_stmt, compile_state, **kw
+            self, insert_stmt, compile_state, toplevel, **kw
         )
         crud_params_single = crud_params_struct.single_params
 
@@ -4303,9 +4344,11 @@ class SQLCompiler(Compiled):
                 [expr for _, expr, _ in crud_params_single]
             )
 
-        if self.returning or insert_stmt._returning:
+        if self.implicit_returning or insert_stmt._returning:
             returning_clause = self.returning_clause(
-                insert_stmt, self.returning or insert_stmt._returning
+                insert_stmt,
+                self.implicit_returning or insert_stmt._returning,
+                populate_result_map=toplevel,
             )
 
             if self.returning_precedes_values:
@@ -4449,7 +4492,7 @@ class SQLCompiler(Compiled):
             update_stmt, update_stmt.table, render_extra_froms, **kw
         )
         crud_params_struct = crud._get_crud_params(
-            self, update_stmt, compile_state, **kw
+            self, update_stmt, compile_state, toplevel, **kw
         )
         crud_params = crud_params_struct.single_params
 
@@ -4473,10 +4516,12 @@ class SQLCompiler(Compiled):
             )
         )
 
-        if self.returning or update_stmt._returning:
+        if self.implicit_returning or update_stmt._returning:
             if self.returning_precedes_values:
                 text += " " + self.returning_clause(
-                    update_stmt, self.returning or update_stmt._returning
+                    update_stmt,
+                    self.implicit_returning or update_stmt._returning,
+                    populate_result_map=toplevel,
                 )
 
         if extra_froms:
@@ -4502,10 +4547,12 @@ class SQLCompiler(Compiled):
             text += " " + limit_clause
 
         if (
-            self.returning or update_stmt._returning
+            self.implicit_returning or update_stmt._returning
         ) and not self.returning_precedes_values:
             text += " " + self.returning_clause(
-                update_stmt, self.returning or update_stmt._returning
+                update_stmt,
+                self.implicit_returning or update_stmt._returning,
+                populate_result_map=toplevel,
             )
 
         if self.ctes:
@@ -4585,7 +4632,9 @@ class SQLCompiler(Compiled):
         if delete_stmt._returning:
             if self.returning_precedes_values:
                 text += " " + self.returning_clause(
-                    delete_stmt, delete_stmt._returning
+                    delete_stmt,
+                    delete_stmt._returning,
+                    populate_result_map=toplevel,
                 )
 
         if extra_froms:
@@ -4608,7 +4657,9 @@ class SQLCompiler(Compiled):
 
         if delete_stmt._returning and not self.returning_precedes_values:
             text += " " + self.returning_clause(
-                delete_stmt, delete_stmt._returning
+                delete_stmt,
+                delete_stmt._returning,
+                populate_result_map=toplevel,
             )
 
         if self.ctes:
@@ -4685,7 +4736,14 @@ class StrSQLCompiler(SQLCompiler):
     def visit_sequence(self, seq, **kw):
         return "<next sequence value: %s>" % self.preparer.format_sequence(seq)
 
-    def returning_clause(self, stmt, returning_cols):
+    def returning_clause(
+        self,
+        stmt: UpdateBase,
+        returning_cols: Sequence[ColumnElement[Any]],
+        *,
+        populate_result_map: bool,
+        **kw: Any,
+    ) -> str:
         columns = [
             self._label_select_column(None, c, True, False, {})
             for c in base._select_iterables(returning_cols)
@@ -4733,6 +4791,8 @@ class StrSQLCompiler(SQLCompiler):
 
 
 class DDLCompiler(Compiled):
+    is_ddl = True
+
     if TYPE_CHECKING:
 
         def __init__(
index 91a3f70c91fb4a08fa17a829cb24a466377cbda9..f6db2c4b25837c4c5f912f888513b495ad13a5af 100644 (file)
@@ -87,6 +87,7 @@ def _get_crud_params(
     compiler: SQLCompiler,
     stmt: ValuesBase,
     compile_state: DMLState,
+    toplevel: bool,
     **kw: Any,
 ) -> _CrudParams:
     """create a set of tuples representing column/string pairs for use
@@ -99,10 +100,33 @@ def _get_crud_params(
 
     """
 
+    # note: the _get_crud_params() system was written with the notion in mind
+    # that INSERT, UPDATE, DELETE are always the top level statement and
+    # that there is only one of them.  With the addition of CTEs that can
+    # make use of DML, this assumption is no longer accurate; the DML
+    # statement is not necessarily the top-level "row returning" thing
+    # and it is also theoretically possible (fortunately nobody has asked yet)
+    # to have a single statement with multiple DMLs inside of it via CTEs.
+
+    # the current _get_crud_params() design doesn't accommodate these cases
+    # right now.  It "just works" for a CTE that has a single DML inside of
+    # it, and for a CTE with multiple DML, it's not clear what would happen.
+
+    # overall, the "compiler.XYZ" collections here would need to be in a
+    # per-DML structure of some kind, and DefaultDialect would need to
+    # navigate these collections on a per-statement basis, with additional
+    # emphasis on the "toplevel returning data" statement.  However we
+    # still need to run through _get_crud_params() for all DML as we have
+    # Python / SQL generated column defaults that need to be rendered.
+
+    # if there is user need for this kind of thing, it's likely a post 2.0
+    # kind of change as it would require deep changes to DefaultDialect
+    # as well as here.
+
     compiler.postfetch = []
     compiler.insert_prefetch = []
     compiler.update_prefetch = []
-    compiler.returning = []
+    compiler.implicit_returning = []
 
     # getters - these are normally just column.key,
     # but in the case of mysql multi-table update, the rules for
@@ -213,6 +237,7 @@ def _get_crud_params(
             _col_bind_name,
             check_columns,
             values,
+            toplevel,
             kw,
         )
     else:
@@ -226,6 +251,7 @@ def _get_crud_params(
             _col_bind_name,
             check_columns,
             values,
+            toplevel,
             kw,
         )
 
@@ -419,6 +445,7 @@ def _scan_insert_from_select_cols(
     _col_bind_name,
     check_columns,
     values,
+    toplevel,
     kw,
 ):
 
@@ -427,7 +454,7 @@ def _scan_insert_from_select_cols(
         implicit_returning,
         implicit_return_defaults,
         postfetch_lastrowid,
-    ) = _get_returning_modifiers(compiler, stmt, compile_state)
+    ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
 
     cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
 
@@ -472,6 +499,7 @@ def _scan_cols(
     _col_bind_name,
     check_columns,
     values,
+    toplevel,
     kw,
 ):
     (
@@ -479,7 +507,7 @@ def _scan_cols(
         implicit_returning,
         implicit_return_defaults,
         postfetch_lastrowid,
-    ) = _get_returning_modifiers(compiler, stmt, compile_state)
+    ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
 
     if compile_state._parameter_ordering:
         parameter_ordering = [
@@ -556,11 +584,11 @@ def _scan_cols(
                 # column has a DDL-level default, and is either not a pk
                 # column or we don't need the pk.
                 if implicit_return_defaults and c in implicit_return_defaults:
-                    compiler.returning.append(c)
+                    compiler.implicit_returning.append(c)
                 elif not c.primary_key:
                     compiler.postfetch.append(c)
             elif implicit_return_defaults and c in implicit_return_defaults:
-                compiler.returning.append(c)
+                compiler.implicit_returning.append(c)
             elif (
                 c.primary_key
                 and c is not stmt.table._autoincrement_column
@@ -628,7 +656,7 @@ def _append_param_parameter(
 
         if compile_state.isupdate:
             if implicit_return_defaults and c in implicit_return_defaults:
-                compiler.returning.append(c)
+                compiler.implicit_returning.append(c)
 
             else:
                 compiler.postfetch.append(c)
@@ -636,12 +664,12 @@ def _append_param_parameter(
             if c.primary_key:
 
                 if implicit_returning:
-                    compiler.returning.append(c)
+                    compiler.implicit_returning.append(c)
                 elif compiler.dialect.postfetch_lastrowid:
                     compiler.postfetch_lastrowid = True
 
             elif implicit_return_defaults and c in implicit_return_defaults:
-                compiler.returning.append(c)
+                compiler.implicit_returning.append(c)
 
             else:
                 # postfetch specifically means, "we can SELECT the row we just
@@ -674,7 +702,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
                         compiler.process(c.default, **kw),
                     )
                 )
-            compiler.returning.append(c)
+            compiler.implicit_returning.append(c)
         elif c.default.is_clause_element:
             values.append(
                 (
@@ -683,7 +711,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
                     compiler.process(c.default.arg.self_group(), **kw),
                 )
             )
-            compiler.returning.append(c)
+            compiler.implicit_returning.append(c)
         else:
             # client side default.  OK we can't use RETURNING, need to
             # do a "prefetch", which in fact fetches the default value
@@ -696,7 +724,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
                 )
             )
     elif c is stmt.table._autoincrement_column or c.server_default is not None:
-        compiler.returning.append(c)
+        compiler.implicit_returning.append(c)
     elif not c.nullable:
         # no .default, no .server_default, not autoincrement, we have
         # no indication this primary key column will have any value
@@ -794,7 +822,7 @@ def _append_param_insert_hasdefault(
                 )
             )
             if implicit_return_defaults and c in implicit_return_defaults:
-                compiler.returning.append(c)
+                compiler.implicit_returning.append(c)
             elif not c.primary_key:
                 compiler.postfetch.append(c)
     elif c.default.is_clause_element:
@@ -807,7 +835,7 @@ def _append_param_insert_hasdefault(
         )
 
         if implicit_return_defaults and c in implicit_return_defaults:
-            compiler.returning.append(c)
+            compiler.implicit_returning.append(c)
         elif not c.primary_key:
             # don't add primary key column to postfetch
             compiler.postfetch.append(c)
@@ -870,7 +898,7 @@ def _append_param_update(
                 )
             )
             if implicit_return_defaults and c in implicit_return_defaults:
-                compiler.returning.append(c)
+                compiler.implicit_returning.append(c)
             else:
                 compiler.postfetch.append(c)
         else:
@@ -886,7 +914,7 @@ def _append_param_update(
             )
     elif c.server_onupdate is not None:
         if implicit_return_defaults and c in implicit_return_defaults:
-            compiler.returning.append(c)
+            compiler.implicit_returning.append(c)
         else:
             compiler.postfetch.append(c)
     elif (
@@ -894,7 +922,7 @@ def _append_param_update(
         and (stmt._return_defaults_columns or not stmt._return_defaults)
         and c in implicit_return_defaults
     ):
-        compiler.returning.append(c)
+        compiler.implicit_returning.append(c)
 
 
 @overload
@@ -1195,10 +1223,11 @@ def _get_stmt_parameter_tuples_params(
             values.append((k, col_expr, v))
 
 
-def _get_returning_modifiers(compiler, stmt, compile_state):
+def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
 
     need_pks = (
-        compile_state.isinsert
+        toplevel
+        and compile_state.isinsert
         and not stmt._inline
         and (
             not compiler.for_executemany
index 8a3a1b38f0cf6c8bdc115b32f97e7d00cc9f5665..f23ba2e6e2e3c27e16da8614665c246089156389 100644 (file)
@@ -463,11 +463,18 @@ class UpdateBase(
         )
         return self
 
-    @util.non_memoized_property
+    def corresponding_column(
+        self, column: ColumnElement[Any], require_embedded: bool = False
+    ) -> Optional[ColumnElement[Any]]:
+        return self.exported_columns.corresponding_column(
+            column, require_embedded=require_embedded
+        )
+
+    @util.ro_memoized_property
     def _all_selected_columns(self) -> _SelectIterable:
         return [c for c in _select_iterables(self._returning)]
 
-    @property
+    @util.ro_memoized_property
     def exported_columns(
         self,
     ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]:
index aec29d1b2e2107c48d1c24952061a7f6993191e6..77813b98f783062f8df08a8d86ce3df15e169135 100644 (file)
@@ -305,6 +305,7 @@ class ClauseElement(
 
     is_clause_element = True
     is_selectable = False
+    is_dml = False
     _is_column_element = False
     _is_table = False
     _is_textual = False
index 6504449f1dc49d447304e665bcfb7f41df37e245..292225ce2e529f7c1f8796e7099f2eb1ec78652f 100644 (file)
@@ -127,6 +127,9 @@ if TYPE_CHECKING:
 
 
 _ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"]
+_LabelConventionCallable = Callable[
+    [Union["ColumnElement[Any]", "TextClause"]], Optional[str]
+]
 
 
 class _JoinTargetProtocol(Protocol):
@@ -183,7 +186,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
     def selectable(self) -> ReturnsRows:
         return self
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _all_selected_columns(self) -> _SelectIterable:
         """A sequence of column expression objects that represents the
         "selected" columns of this :class:`_expression.ReturnsRows`.
@@ -3277,7 +3280,7 @@ class SelectBase(
         """
         raise NotImplementedError()
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _all_selected_columns(self) -> _SelectIterable:
         """A sequence of expressions that correspond to what is rendered
         in the columns clause, including :class:`_sql.TextClause`
@@ -3586,7 +3589,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
     ) -> None:
         self.element._generate_fromclause_column_proxies(subquery)
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _all_selected_columns(self) -> _SelectIterable:
         return self.element._all_selected_columns
 
@@ -4297,7 +4300,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect):
         for select in self.selects:
             select._refresh_for_new_column(column)
 
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _all_selected_columns(self) -> _SelectIterable:
         return self.selects[0]._all_selected_columns
 
@@ -4408,7 +4411,7 @@ class SelectState(util.MemoizedSlots, CompileState):
     @classmethod
     def _column_naming_convention(
         cls, label_style: SelectLabelStyle
-    ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]:
+    ) -> _LabelConventionCallable:
 
         table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
         dedupe = label_style is not LABEL_STYLE_NONE
@@ -5984,7 +5987,7 @@ class Select(
         )
         return cc.as_readonly()
 
-    @HasMemoized.memoized_attribute
+    @HasMemoized_ro_memoized_attribute
     def _all_selected_columns(self) -> _SelectIterable:
         meth = SelectState.get_plugin_class(self).all_selected_columns
         return list(meth(self))
@@ -6537,14 +6540,7 @@ class TextualSelect(SelectBase):
             (c.key, c) for c in self.column_args
         ).as_readonly()
 
-    # def _generate_columns_plus_names(
-    #    self, anon_for_dupe_key: bool
-    # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]:
-    #    return Select._generate_columns_plus_names(
-    #        self, anon_for_dupe_key=anon_for_dupe_key
-    #    )
-
-    @util.non_memoized_property
+    @util.ro_non_memoized_property
     def _all_selected_columns(self) -> _SelectIterable:
         return self.column_args
 
index 02a1ec8aab23b8e46daeeee479e3a64c994dba5c..5ef2c6f22c8b8772cf309790c644a0509060bc24 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -162,5 +162,4 @@ mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test
 mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
 mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008
 docker_mssql = mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test
-oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521
-oracle8 = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+oracle = oracle+cx_oracle://scott:tiger@oracle18c
index e827fa56c04b8c9849527be477100b2970bd49b4..60cdb2577aeffa41c7a6e5340e5b218e7f38860b 100644 (file)
@@ -31,6 +31,7 @@ from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.suite import test_select
@@ -54,7 +55,7 @@ class DialectTest(fixtures.TestBase):
         ):
             assert_raises_message(
                 exc.InvalidRequestError,
-                "cx_Oracle version 5.2 and above are supported",
+                "cx_Oracle version 7 and above are supported",
                 cx_oracle.OracleDialect_cx_oracle,
                 dbapi=mock.Mock(),
             )
@@ -62,7 +63,7 @@ class DialectTest(fixtures.TestBase):
         with mock.patch(
             "sqlalchemy.dialects.oracle.cx_oracle.OracleDialect_cx_oracle."
             "_parse_cx_oracle_ver",
-            lambda self, vers: (5, 3, 1),
+            lambda self, vers: (7, 1, 0),
         ):
             cx_oracle.OracleDialect_cx_oracle(dbapi=mock.Mock())
 
@@ -301,19 +302,6 @@ class EncodingErrorsTest(fixtures.TestBase):
         else:
             assert_raises(UnicodeDecodeError, outconverter, utf8_w_errors)
 
-    @_oracle_char_combinations
-    def test_older_cx_oracle_warning(self, cx_Oracle, cx_oracle_type):
-        cx_Oracle.version = "6.3"
-
-        with testing.expect_warnings(
-            r"cx_oracle version \(6, 3\) does not support encodingErrors"
-        ):
-            dialect = cx_oracle.dialect(
-                dbapi=cx_Oracle, encoding_errors="ignore"
-            )
-
-            eq_(dialect._cursor_var_unicode_kwargs, {})
-
     @_oracle_char_combinations
     def test_encoding_errors_cx_oracle(
         self,
@@ -478,6 +466,23 @@ end;
         eq_(result.out_parameters, {"x_out": 10, "y_out": 75, "z_out": None})
         assert isinstance(result.out_parameters["x_out"], int)
 
+    def test_no_out_params_w_returning(self, connection, metadata):
+        t = Table("t", metadata, Column("x", Integer), Column("y", Integer))
+        metadata.create_all(connection)
+        stmt = (
+            t.insert()
+            .values(x=5, y=10)
+            .returning(outparam("my_param", Integer), t.c.x)
+        )
+
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"Using explicit outparam\(\) objects with "
+            r"UpdateBase.returning\(\) in the same Core DML statement "
+            "is not supported in the Oracle dialect.",
+        ):
+            connection.execute(stmt)
+
     @classmethod
     def teardown_test_class(cls):
         with testing.db.begin() as conn:
index e1c610701cbd0bd70fc892709f81d84bfc0dff32..8f0a4b02294b02f172afcef0c3bd265d3d283cc1 100644 (file)
@@ -483,15 +483,15 @@ class ImplicitReturningFlagTest(fixtures.TestBase):
                     ):
                         stmt.compile(conn)
                 else:
-                    eq_(stmt.compile(conn).returning, [t.c.id])
+                    eq_(stmt.compile(conn).implicit_returning, [t.c.id])
             elif (
                 implicit_returning is None
                 and testing.db.dialect.implicit_returning
             ):
-                eq_(stmt.compile(conn).returning, [t.c.id])
+                eq_(stmt.compile(conn).implicit_returning, [t.c.id])
             else:
-                eq_(stmt.compile(conn).returning, [])
+                eq_(stmt.compile(conn).implicit_returning, [])
 
             # table setting it to False disables it
             stmt2 = insert(t2).values(data="data")
-            eq_(stmt2.compile(conn).returning, [])
+            eq_(stmt2.compile(conn).implicit_returning, [])
index 0b079055cd2e565a009e3ca061d532f6de390087..c3e5b586a6e3b2b431818a05a75616a6e01c820f 100644 (file)
@@ -224,8 +224,10 @@ class DeferredReflectionTest(DeferredReflectBase):
         eq_(len(_DeferredMapperConfig._configs), 2)
         del Address
         gc_collect()
+        gc_collect()
         eq_(len(_DeferredMapperConfig._configs), 1)
         DeferredReflection.prepare(testing.db)
+        gc_collect()
         assert not _DeferredMapperConfig._configs
 
 
index 8147589368a53ff9a98edd73dd378dae88c4b138..8e568aef05b28c52f62f8d6f54bc2f6333a43f63 100644 (file)
@@ -1585,6 +1585,8 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
         user_is = user._sa_instance_state
         del user
         gc_collect()
+        gc_collect()
+        gc_collect()
         assert user_is.obj() is None
 
         assert len(s.identity_map) == 0
index bc84d444758ad7eb48774b639cdb53fa90930617..ec4fb2d68296a2f390d3d75297d88336eb189035 100644 (file)
@@ -1267,6 +1267,8 @@ class AutoExpireTest(_LocalFixture):
         assert u1_state not in s._deleted
         del u1
         gc_collect()
+        gc_collect()
+        gc_collect()
         assert u1_state.obj() is None
 
         s.rollback()
index 74a60bd215a4bdf2ba461c268ea6acfc08136174..5f02fde4c94fbbc692dfac1a7ea158694fe26b47 100644 (file)
@@ -1423,7 +1423,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
         stmt = table.insert().return_defaults().values(id=func.foobar())
         compiled = stmt.compile(dialect=sqlite.dialect(), column_keys=["data"])
         eq_(compiled.postfetch, [])
-        eq_(compiled.returning, [])
+        eq_(compiled.implicit_returning, [])
 
         self.assert_compile(
             stmt,
@@ -1452,7 +1452,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=returning_dialect, column_keys=["data"]
         )
         eq_(compiled.postfetch, [])
-        eq_(compiled.returning, [table.c.id])
+        eq_(compiled.implicit_returning, [table.c.id])
 
         self.assert_compile(
             stmt,
@@ -1482,7 +1482,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=returning_dialect, column_keys=["data"]
         )
         eq_(compiled.postfetch, [])
-        eq_(compiled.returning, [table.c.id])
+        eq_(compiled.implicit_returning, [table.c.id])
 
         self.assert_compile(
             stmt,