From: Mike Bayer Date: Mon, 4 Apr 2022 14:13:23 +0000 (-0400) Subject: cx_Oracle modernize X-Git-Tag: rel_2_0_0b1~370^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2acc9ec1281b2818bd44804f040d94ec46215688;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git cx_Oracle modernize Full "RETURNING" support is implemented for the cx_Oracle dialect, meaning multiple RETURNING rows are now recived for DML statements that produce more than one row for RETURNING. cx_Oracle 7 is now the minimum version for cx_Oracle. Getting Oracle to do multirow returning took about 5 minutes. however, getting Oracle's RETURNING system to integrate with ORM-enabled insert, update, delete, is a big deal because that architecture wasn't really working very robustly, including some recent changes in 1.4 for FromStatement were done in a hurry, so this patch also cleans up the FromStatement situation and begins to establish it more concretely as the base for all ReturnsRows / TextClause ORM scenarios. Fixes: #6245 Change-Id: I2b4e6007affa51ce311d2d5baa3917f356ab961f --- diff --git a/doc/build/changelog/unreleased_20/6245.rst b/doc/build/changelog/unreleased_20/6245.rst new file mode 100644 index 0000000000..1247544f1c --- /dev/null +++ b/doc/build/changelog/unreleased_20/6245.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: oracle, feature + :tickets: 6245 + + Full "RETURNING" support is implemented for the cx_Oracle dialect, meaning + multiple RETURNING rows are now recived for DML statements that produce + more than one row for RETURNING. + + +.. change:: + :tags: oracle + + cx_Oracle 7 is now the minimum version for cx_Oracle. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 07ff495a71..ac02b98a04 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -805,10 +805,13 @@ https://msdn.microsoft.com/en-us/library/ms175095.aspx. """ # noqa +from __future__ import annotations + import codecs import datetime import operator import re +from typing import TYPE_CHECKING from . import information_schema as ischema from .json import JSON @@ -833,6 +836,7 @@ from ...sql import func from ...sql import quoted_name from ...sql import roles from ...sql import util as sql_util +from ...sql._typing import is_sql_compiler from ...types import BIGINT from ...types import BINARY from ...types import CHAR @@ -849,6 +853,10 @@ from ...types import TEXT from ...types import VARCHAR from ...util import update_wrapper +if TYPE_CHECKING: + from ...sql.compiler import SQLCompiler + from ...sql.dml import DMLState + from ...sql.selectable import TableClause # https://sqlserverbuilds.blogspot.com/ MS_2017_VERSION = (14,) @@ -1623,6 +1631,8 @@ class MSExecutionContext(default.DefaultExecutionContext): _lastrowid = None _rowcount = None + dialect: MSDialect + def _opt_encode(self, statement): if self.compiled and self.compiled.schema_translate_map: @@ -1636,13 +1646,20 @@ class MSExecutionContext(default.DefaultExecutionContext): """Activate IDENTITY_INSERT if needed.""" if self.isinsert: + if TYPE_CHECKING: + assert is_sql_compiler(self.compiled) + assert isinstance(self.compiled.compile_state, DMLState) + assert isinstance( + self.compiled.compile_state.dml_table, TableClause + ) + tbl = self.compiled.compile_state.dml_table id_column = tbl._autoincrement_column - insert_has_identity = (id_column is not None) and ( - not isinstance(id_column.default, Sequence) - ) - if insert_has_identity: + if id_column is not None and ( + not isinstance(id_column.default, Sequence) + ): + insert_has_identity = True compile_state = self.compiled.compile_state self._enable_identity_insert = ( id_column.key in self.compiled_parameters[0] @@ -1655,12 +1672,13 @@ class MSExecutionContext(default.DefaultExecutionContext): ) else: + insert_has_identity = False self._enable_identity_insert = False self._select_lastrowid = ( not self.compiled.inline and insert_has_identity - and not self.compiled.returning + and not self.compiled.effective_returning and not self._enable_identity_insert and not self.executemany ) @@ -1701,8 +1719,10 @@ class MSExecutionContext(default.DefaultExecutionContext): self._lastrowid = int(row[0]) elif ( - self.isinsert or self.isupdate or self.isdelete - ) and self.compiled.returning: + self.compiled is not None + and is_sql_compiler(self.compiled) + and self.compiled.effective_returning + ): self.cursor_fetch_strategy = ( _cursor.FullyBufferedCursorFetchStrategy( self.cursor, @@ -1712,6 +1732,12 @@ class MSExecutionContext(default.DefaultExecutionContext): ) if self._enable_identity_insert: + if TYPE_CHECKING: + assert is_sql_compiler(self.compiled) + assert isinstance(self.compiled.compile_state, DMLState) + assert isinstance( + self.compiled.compile_state.dml_table, TableClause + ) conn._cursor_execute( self.cursor, self._opt_encode( @@ -2065,17 +2091,21 @@ class MSSQLCompiler(compiler.SQLCompiler): ) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - def returning_clause(self, stmt, returning_cols): + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): # SQL server returning clause requires that the columns refer to # the virtual table names "inserted" or "deleted". Here, we make # a simple alias of our table with that name, and then adapt the # columns we have from the list of RETURNING columns to that new name # so that they render as "inserted." / "deleted.". - if self.isinsert or self.isupdate: + if stmt.is_insert or stmt.is_update: target = stmt.table.alias("inserted") - else: + elif stmt.is_delete: target = stmt.table.alias("deleted") + else: + assert False, "expected Insert, Update or Delete statement" adapter = sql_util.ClauseAdapter(target) @@ -2091,6 +2121,7 @@ class MSSQLCompiler(compiler.SQLCompiler): self._label_returning_column( stmt, adapter.traverse(c), + populate_result_map, {"result_map_targets": (c,)}, ) for c in expression._select_iterables(returning_cols) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 3ee38c0cf1..39a542cce8 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1089,7 +1089,9 @@ class OracleCompiler(compiler.SQLCompiler): return " " + alias_name_text - def returning_clause(self, stmt, returning_cols): + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): columns = [] binds = [] @@ -1122,23 +1124,34 @@ class OracleCompiler(compiler.SQLCompiler): self.bindparam_string(self._truncate_bindparam(outparam)) ) - # ensure the ExecutionContext.get_out_parameters() method is - # *not* called; the cx_Oracle dialect wants to handle these - # parameters separately - self.has_out_parameters = False + # has_out_parameters would in a normal case be set to True + # as a result of the compiler visiting an outparam() object. + # in this case, the above outparam() objects are not being + # visited. Ensure the statement itself didn't have other + # outparam() objects independently. + # technically, this could be supported, but as it would be + # a very strange use case without a clear rationale, disallow it + if self.has_out_parameters: + raise exc.InvalidRequestError( + "Using explicit outparam() objects with " + "UpdateBase.returning() in the same Core DML statement " + "is not supported in the Oracle dialect." + ) - columns.append(self.process(col_expr, within_columns_clause=False)) + self._oracle_returning = True - self._add_to_result_map( - getattr(col_expr, "name", col_expr._anon_name_label), - getattr(col_expr, "name", col_expr._anon_name_label), - ( - column, - getattr(column, "name", None), - getattr(column, "key", None), - ), - column.type, - ) + columns.append(self.process(col_expr, within_columns_clause=False)) + if populate_result_map: + self._add_to_result_map( + getattr(col_expr, "name", col_expr._anon_name_label), + getattr(col_expr, "name", col_expr._anon_name_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, + ) return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) @@ -1510,6 +1523,7 @@ class OracleDialect(default.DefaultDialect): max_identifier_length = 128 implicit_returning = True + full_returning = True div_is_floordiv = False diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 5208f96718..9f33945331 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -352,8 +352,7 @@ RETURNING Support ----------------- The cx_Oracle dialect implements RETURNING using OUT parameters. -The dialect supports RETURNING fully, however cx_Oracle 6 is recommended -for complete support. +The dialect supports RETURNING fully. .. _cx_oracle_lob: @@ -430,6 +429,8 @@ SQLAlchemy type (or a subclass of such). as better integration of outputtypehandlers. """ # noqa +from __future__ import annotations + import decimal import random import re @@ -444,6 +445,7 @@ from ... import util from ...engine import cursor as _cursor from ...engine import interfaces from ...engine import processors +from ...sql._typing import is_sql_compiler class _OracleInteger(sqltypes.Integer): @@ -452,10 +454,13 @@ class _OracleInteger(sqltypes.Integer): # 208#issuecomment-409715955 return int - def _cx_oracle_var(self, dialect, cursor): + def _cx_oracle_var(self, dialect, cursor, arraysize=None): cx_Oracle = dialect.dbapi return cursor.var( - cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int + cx_Oracle.STRING, + 255, + arraysize=arraysize if arraysize is not None else cursor.arraysize, + outconverter=int, ) def _cx_oracle_outputtypehandler(self, dialect): @@ -494,8 +499,6 @@ class _OracleNumeric(sqltypes.Numeric): def _cx_oracle_outputtypehandler(self, dialect): cx_Oracle = dialect.dbapi - is_cx_oracle_6 = dialect._is_cx_oracle_6 - def handler(cursor, name, default_type, size, precision, scale): outconverter = None @@ -506,11 +509,8 @@ class _OracleNumeric(sqltypes.Numeric): # allows for float("inf") to be handled type_ = default_type outconverter = decimal.Decimal - elif is_cx_oracle_6: - type_ = decimal.Decimal else: - type_ = cx_Oracle.STRING - outconverter = dialect._to_decimal + type_ = decimal.Decimal else: if self.is_number and scale == 0: # integer. cx_Oracle is observed to handle the widest @@ -525,11 +525,8 @@ class _OracleNumeric(sqltypes.Numeric): if default_type == cx_Oracle.NATIVE_FLOAT: type_ = default_type outconverter = decimal.Decimal - elif is_cx_oracle_6: - type_ = decimal.Decimal else: - type_ = cx_Oracle.STRING - outconverter = dialect._to_decimal + type_ = decimal.Decimal else: if self.is_number and scale == 0: # integer. cx_Oracle is observed to handle the widest @@ -670,6 +667,8 @@ class _OracleRowid(oracle.ROWID): class OracleCompiler_cx_oracle(OracleCompiler): _oracle_cx_sql_compiler = True + _oracle_returning = False + def bindparam_string(self, name, **kw): quote = getattr(name, "quote", None) if ( @@ -696,7 +695,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): def _generate_out_parameter_vars(self): # check for has_out_parameters or RETURNING, create cx_Oracle.var # objects if so - if self.compiled.returning or self.compiled.has_out_parameters: + if self.compiled.has_out_parameters or self.compiled._oracle_returning: quoted_bind_names = self.compiled.escaped_bind_names for bindparam in self.compiled.binds.values(): if bindparam.isoutparam: @@ -705,7 +704,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if hasattr(type_impl, "_cx_oracle_var"): self.out_parameters[name] = type_impl._cx_oracle_var( - self.dialect, self.cursor + self.dialect, self.cursor, arraysize=1 ) else: dbtype = type_impl.get_dbapi_type(self.dialect.dbapi) @@ -726,10 +725,14 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): cx_Oracle.NCLOB, ): self.out_parameters[name] = self.cursor.var( - dbtype, outconverter=lambda value: value.read() + dbtype, + outconverter=lambda value: value.read(), + arraysize=1, ) else: - self.out_parameters[name] = self.cursor.var(dbtype) + self.out_parameters[name] = self.cursor.var( + dbtype, arraysize=1 + ) self.parameters[0][ quoted_bind_names.get(name, name) ] = self.out_parameters[name] @@ -782,22 +785,33 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): self._generate_cursor_outputtype_handler() def post_exec(self): - if self.compiled and self.out_parameters and self.compiled.returning: + if ( + self.compiled + and is_sql_compiler(self.compiled) + and self.compiled._oracle_returning + ): # create a fake cursor result from the out parameters. unlike # get_out_parameter_values(), the result-row handlers here will be # applied at the Result level - returning_params = [ - self.dialect._returningval(self.out_parameters["ret_%d" % i]) - for i in range(len(self.out_parameters)) + + numrows = len(self.out_parameters["ret_0"].values[0]) + numcols = len(self.out_parameters) + + initial_buffer = [ + tuple( + self.out_parameters[f"ret_{j}"].values[0][i] + for j in range(numcols) + ) + for i in range(numrows) ] fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy( self.cursor, [ - (getattr(col, "name", col._anon_name_label), None) - for col in self.compiled.returning + (entry.keyname, None) + for entry in self.compiled._result_columns ], - initial_buffer=[tuple(returning_params)], + initial_buffer=initial_buffer, ) self.cursor_fetch_strategy = fetch_strategy @@ -908,18 +922,11 @@ class OracleDialect_cx_oracle(OracleDialect): self.cx_oracle_ver = (0, 0, 0) else: self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version) - if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0): + if self.cx_oracle_ver < (7,) and self.cx_oracle_ver > (0, 0, 0): raise exc.InvalidRequestError( - "cx_Oracle version 5.2 and above are supported" + "cx_Oracle version 7 and above are supported" ) - if encoding_errors and self.cx_oracle_ver < (6, 4): - util.warn( - "cx_oracle version %r does not support encodingErrors" - % (self.cx_oracle_ver,) - ) - self._cursor_var_unicode_kwargs = util.immutabledict() - self.include_set_input_sizes = { cx_Oracle.DATETIME, cx_Oracle.NCLOB, @@ -937,24 +944,6 @@ class OracleDialect_cx_oracle(OracleDialect): self._paramval = lambda value: value.getvalue() - # https://github.com/oracle/python-cx_Oracle/issues/176#issuecomment-386821291 - # https://github.com/oracle/python-cx_Oracle/issues/224 - self._values_are_lists = self.cx_oracle_ver >= (6, 3) - if self._values_are_lists: - cx_Oracle.__future__.dml_ret_array_val = True - - def _returningval(value): - try: - return value.values[0][0] - except IndexError: - return None - - self._returningval = _returningval - else: - self._returningval = self._paramval - - self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,) - def _parse_cx_oracle_ver(self, version): m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 9bfde47688..9994528049 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2401,10 +2401,11 @@ class PGCompiler(compiler.SQLCompiler): return tmp - def returning_clause(self, stmt, returning_cols): - + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): columns = [ - self._label_returning_column(stmt, c) + self._label_returning_column(stmt, c, populate_result_map) for c in expression._select_iterables(returning_cols) ] diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 391368c5fc..07783ced78 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -478,7 +478,7 @@ class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg): if ( self._psycopg2_fetched_rows and self.compiled - and self.compiled.returning + and self.compiled.effective_returning ): # psycopg2 execute_values will provide for a real cursor where # cursor.description works correctly. however, it executes the @@ -736,7 +736,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): statement, parameters, template=executemany_values, - fetch=bool(context.compiled.returning), + fetch=bool(context.compiled.effective_returning), **kwargs, ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 85ce91deb1..4d700866fe 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -933,8 +933,11 @@ class DefaultExecutionContext(ExecutionContext): assert isinstance(compiled.statement, UpdateBase) self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) - self._is_implicit_returning = bool( - compiled.returning and not compiled.statement._returning + self._is_implicit_returning = is_implicit_returning = bool( + compiled.implicit_returning + ) + assert not ( + is_implicit_returning and compiled.statement._returning ) if not parameters: @@ -1165,12 +1168,6 @@ class DefaultExecutionContext(ExecutionContext): else: return () - @util.memoized_property - def returning_cols(self) -> Optional[Sequence[ColumnsClauseRole]]: - if TYPE_CHECKING: - assert isinstance(self.compiled, SQLCompiler) - return self.compiled.returning - @util.memoized_property def no_parameters(self): return self.execution_options.get("no_parameters", False) @@ -1349,6 +1346,7 @@ class DefaultExecutionContext(ExecutionContext): result = _cursor.CursorResult(self, strategy, cursor_description) compiled = self.compiled + if ( compiled and not self.isddl diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py new file mode 100644 index 0000000000..e9ddf6d158 --- /dev/null +++ b/lib/sqlalchemy/orm/_typing.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Union + + +if TYPE_CHECKING: + from .mapper import Mapper + from .util import AliasedInsp + +_EntityType = Union[Mapper, AliasedInsp] diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index bcd3e7d233..edd3fb56bf 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -8,7 +8,15 @@ from __future__ import annotations import itertools +from typing import Any +from typing import Dict from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from . import attributes from . import interfaces @@ -27,23 +35,39 @@ from .. import future from .. import inspect from .. import sql from .. import util -from ..sql import ClauseElement from ..sql import coercions from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql._typing import is_dml +from ..sql._typing import is_insert_update +from ..sql._typing import is_select_base from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState +from ..sql.base import Executable from ..sql.base import Options +from ..sql.dml import UpdateBase +from ..sql.elements import GroupedElement +from ..sql.elements import TextClause from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from ..sql.selectable import LABEL_STYLE_NONE from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import ReturnsRows from ..sql.selectable import Select +from ..sql.selectable import SelectLabelStyle from ..sql.selectable import SelectState from ..sql.visitors import InternalTraversal +if TYPE_CHECKING: + from ._typing import _EntityType + from ..sql.compiler import _CompilerStackEntry + from ..sql.dml import _DMLTableElement + from ..sql.elements import ColumnElement + from ..sql.selectable import _LabelConventionCallable + from ..sql.selectable import SelectBase + _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() @@ -144,16 +168,6 @@ class QueryContext: self.yield_per = load_options._yield_per self.identity_token = load_options._refresh_identity_token - if self.yield_per and compile_state._no_yield_pers: - raise sa_exc.InvalidRequestError( - "The yield_per Query option is currently not " - "compatible with %s eager loading. Please " - "specify lazyload('*') or query.enable_eagerloads(False) in " - "order to " - "proceed with query.yield_per()." - % ", ".join(compile_state._no_yield_pers) - ) - _orm_load_exec_options = util.immutabledict( {"_result_disable_adapt_to_context": True, "future_result": True} @@ -196,7 +210,24 @@ class ORMCompileState(CompileState): _for_refresh_state = False _render_for_subquery = False - current_path = _path_registry + statement: Union[Select, FromStatement] + select_statement: Union[Select, FromStatement] + _entities: List[_QueryEntity] + _polymorphic_adapters: Dict[_EntityType, ORMAdapter] + compile_options: Union[ + Type[default_compile_options], default_compile_options + ] + _primary_entity: Optional[_QueryEntity] + use_legacy_query_style: bool + _label_convention: _LabelConventionCallable + primary_columns: List[ColumnElement[Any]] + secondary_columns: List[ColumnElement[Any]] + dedupe_columns: Set[ColumnElement[Any]] + create_eager_joins: List[ + # TODO: this structure is set up by JoinedLoader + Tuple[Any, ...] + ] + current_path: PathRegistry = _path_registry def __init__(self, *arg, **kw): raise NotImplementedError() @@ -208,7 +239,9 @@ class ORMCompileState(CompileState): col_collection.append(obj) @classmethod - def _column_naming_convention(cls, label_style, legacy): + def _column_naming_convention( + cls, label_style: SelectLabelStyle, legacy: bool + ) -> _LabelConventionCallable: if legacy: @@ -388,6 +421,10 @@ class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None _has_mapper_entities = False + statement_container: FromStatement + requested_statement: Union[SelectBase, TextClause, UpdateBase] + dml_table: _DMLTableElement + _has_orm_entities = False multi_row_eager_loaders = False compound_eager_adapter = None @@ -403,6 +440,12 @@ class ORMFromStatementCompileState(ORMCompileState): else: toplevel = True + if not toplevel: + raise sa_exc.CompileError( + "The ORM FromStatement construct only supports being " + "invoked as the topmost statement, as it is only intended to " + "define how result rows should be returned." + ) self = cls.__new__(cls) self._primary_entity = None @@ -417,7 +460,6 @@ class ORMFromStatementCompileState(ORMCompileState): self._entities = [] self._polymorphic_adapters = {} - self._no_yield_pers = set() self.compile_options = statement_container._compile_options @@ -474,37 +516,47 @@ class ORMFromStatementCompileState(ORMCompileState): self.order_by = None - if isinstance( - self.statement, (expression.TextClause, expression.UpdateBase) - ): - + if isinstance(self.statement, expression.TextClause): + # TextClause has no "column" objects at all. for this case, + # we generate columns from our _QueryEntity objects, then + # flip on all the "please match no matter what" parameters. self.extra_criteria_entities = {} - # setup for all entities. Currently, this is not useful - # for eager loaders, as the eager loaders that work are able - # to do their work entirely in row_processor. for entity in self._entities: entity.setup_compile_state(self) - # we did the setup just to get primary columns. - self.statement = _AdHocColumnsStatement( - self.statement, self.primary_columns - ) + compiler._ordered_columns = ( + compiler._textual_ordered_columns + ) = False + + # enable looser result column matching. this is shown to be + # needed by test_query.py::TextTest + compiler._loose_column_name_matching = True + + for c in self.primary_columns: + compiler.process( + c, + within_columns_clause=True, + add_to_result_map=compiler._add_to_result_map, + ) else: - # allow TextualSelect with implicit columns as well - # as select() with ad-hoc columns, see test_query::TextTest + # for everyone else, Select, Insert, Update, TextualSelect, they + # have column objects already. After much + # experimentation here, the best approach seems to be, use + # those columns completely, don't interfere with the compiler + # at all; just in ORM land, use an adapter to convert from + # our ORM columns to whatever columns are in the statement, + # before we look in the result row. If the inner statement is + # not ORM enabled, assume looser col matching based on name + statement_is_orm = ( + self.statement._propagate_attrs.get( + "compile_state_plugin", None + ) + == "orm" + ) self._from_obj_alias = sql.util.ColumnAdapter( - self.statement, adapt_on_names=True + self.statement, adapt_on_names=not statement_is_orm ) - # set up for eager loaders, however if we fix subqueryload - # it should not need to do this here. the model of eager loaders - # that can work entirely in row_processor might be interesting - # here though subqueryloader has a lot of upfront work to do - # see test/orm/test_query.py -> test_related_eagerload_against_text - # for where this part makes a difference. would rather have - # subqueryload figure out what it needs more intelligently. - # for entity in self._entities: - # entity.setup_compile_state(self) return self @@ -515,63 +567,91 @@ class ORMFromStatementCompileState(ORMCompileState): return None -class _AdHocColumnsStatement(ClauseElement): - """internal object created to somewhat act like a SELECT when we - are selecting columns from a DML RETURNING. +class FromStatement(GroupedElement, ReturnsRows, Executable): + """Core construct that represents a load of ORM objects from various + :class:`.ReturnsRows` and other classes including: + :class:`.Select`, :class:`.TextClause`, :class:`.TextualSelect`, + :class:`.CompoundSelect`, :class`.Insert`, :class:`.Update`, + and in theory, :class:`.Delete`. """ - __visit_name__ = None + __visit_name__ = "orm_from_statement" - def __init__(self, text, columns): - self.element = text - self.column_args = [ - coercions.expect(roles.ColumnsClauseRole, c) for c in columns - ] + _compile_options = ORMFromStatementCompileState.default_compile_options - def _generate_cache_key(self): - raise NotImplementedError() + _compile_state_factory = ORMFromStatementCompileState.create_for_statement - def _gen_cache_key(self, anon_map, bindparams): - raise NotImplementedError() + _for_update_arg = None - def _compiler_dispatch( - self, compiler, compound_index=None, asfrom=False, **kw - ): - """provide a fixed _compiler_dispatch method.""" + element: Union[SelectBase, TextClause, UpdateBase] - toplevel = not compiler.stack - entry = ( - compiler._default_stack_entry if toplevel else compiler.stack[-1] - ) + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("element", InternalTraversal.dp_clauseelement), + ] + Executable._executable_traverse_internals - populate_result_map = ( - toplevel - # these two might not be needed - or ( - compound_index == 0 - and entry.get("need_result_map_for_compound", False) + _cache_key_traversal = _traverse_internals + [ + ("_compile_options", InternalTraversal.dp_has_cache_key) + ] + + def __init__(self, entities, element): + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, + ent, + apply_propagate_attrs=self, + post_inspect=True, ) - or entry.get("need_result_map_for_nested", False) + for ent in util.to_list(entities) + ] + self.element = element + self.is_dml = element.is_dml + self._label_style = ( + element._label_style if is_select_base(element) else None ) - if populate_result_map: - compiler._ordered_columns = ( - compiler._textual_ordered_columns - ) = False + def _compiler_dispatch(self, compiler, **kw): - # enable looser result column matching. this is shown to be - # needed by test_query.py::TextTest - compiler._loose_column_name_matching = True + """provide a fixed _compiler_dispatch method. - for c in self.column_args: - compiler.process( - c, - within_columns_clause=True, - add_to_result_map=compiler._add_to_result_map, - ) - return compiler.process(self.element, **kw) + This is roughly similar to using the sqlalchemy.ext.compiler + ``@compiles`` extension. + + """ + + compile_state = self._compile_state_factory(self, compiler, **kw) + + toplevel = not compiler.stack + + if toplevel: + compiler.compile_state = compile_state + + return compiler.process(compile_state.statement, **kw) + + def _ensure_disambiguated_names(self): + return self + + def get_children(self, **kw): + for elem in itertools.chain.from_iterable( + element._from_objects for element in self._raw_columns + ): + yield elem + for elem in super(FromStatement, self).get_children(**kw): + yield elem + + @property + def _all_selected_columns(self): + return self.element._all_selected_columns + + @property + def _returning(self): + return self.element._returning if is_dml(self.element) else None + + @property + def _inline(self): + return self.element._inline if is_insert_update(self.element) else None @sql.base.CompileState.plugin_for("orm", "select") @@ -633,7 +713,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self._entities = [] self._primary_entity = None self._polymorphic_adapters = {} - self._no_yield_pers = set() self.compile_options = select_statement._compile_options @@ -2059,6 +2138,9 @@ class _QueryEntity: _null_column_type = False use_id_for_hash = False + def setup_compile_state(self, compile_state: ORMCompileState) -> None: + raise NotImplementedError() + @classmethod def to_compile_state( cls, compile_state, entities, entities_collection, is_current_entities diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index b9ced44d53..a754bd4f2a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -21,7 +21,6 @@ database to return iterable result sets. from __future__ import annotations import collections.abc as collections_abc -import itertools import operator from typing import Any from typing import Generic @@ -40,9 +39,9 @@ from .base import _assertions from .context import _column_descriptions from .context import _determine_last_joined_entity from .context import _legacy_filter_by_entity_zero +from .context import FromStatement from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState -from .context import ORMFromStatementCompileState from .context import QueryContext from .interfaces import ORMColumnDescription from .interfaces import ORMColumnsClauseRole @@ -71,14 +70,10 @@ from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements from ..sql.selectable import ForUpdateArg -from ..sql.selectable import GroupedElement from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL -from ..sql.selectable import SelectBase -from ..sql.selectable import SelectStatementGrouping -from ..sql.visitors import InternalTraversal if TYPE_CHECKING: from ..sql.selectable import _SetupJoinsElement @@ -2765,91 +2760,6 @@ class Query( return context -class FromStatement(GroupedElement, SelectBase, Executable): - """Core construct that represents a load of ORM objects from a finished - select or text construct. - - """ - - __visit_name__ = "orm_from_statement" - - _compile_options = ORMFromStatementCompileState.default_compile_options - - _compile_state_factory = ORMFromStatementCompileState.create_for_statement - - _for_update_arg = None - - _traverse_internals = [ - ("_raw_columns", InternalTraversal.dp_clauseelement_list), - ("element", InternalTraversal.dp_clauseelement), - ] + Executable._executable_traverse_internals - - _cache_key_traversal = _traverse_internals + [ - ("_compile_options", InternalTraversal.dp_has_cache_key) - ] - - def __init__(self, entities, element): - self._raw_columns = [ - coercions.expect( - roles.ColumnsClauseRole, - ent, - apply_propagate_attrs=self, - post_inspect=True, - ) - for ent in util.to_list(entities) - ] - self.element = element - - def get_label_style(self): - return self._label_style - - def set_label_style(self, label_style): - return SelectStatementGrouping( - self.element.set_label_style(label_style) - ) - - @property - def _label_style(self): - return self.element._label_style - - def _compiler_dispatch(self, compiler, **kw): - - """provide a fixed _compiler_dispatch method. - - This is roughly similar to using the sqlalchemy.ext.compiler - ``@compiles`` extension. - - """ - - compile_state = self._compile_state_factory(self, compiler, **kw) - - toplevel = not compiler.stack - - if toplevel: - compiler.compile_state = compile_state - - return compiler.process(compile_state.statement, **kw) - - def _ensure_disambiguated_names(self): - return self - - def get_children(self, **kw): - for elem in itertools.chain.from_iterable( - element._from_objects for element in self._raw_columns - ): - yield elem - for elem in super(FromStatement, self).get_children(**kw): - yield elem - - @property - def _returning(self): - return self.element._returning if self.element.is_dml else None - - @property - def _inline(self): - return self.element._inline if self.element.is_dml else None - - class AliasOption(interfaces.LoaderOption): inherit_cache = False diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 0a72a93c5c..bc1e0672c4 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -14,6 +14,11 @@ from ..util.typing import Literal from ..util.typing import Protocol if TYPE_CHECKING: + from .compiler import Compiled + from .compiler import DDLCompiler + from .compiler import SQLCompiler + from .dml import UpdateBase + from .dml import ValuesBase from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -38,6 +43,7 @@ if TYPE_CHECKING: from .type_api import TypeEngine from ..util.typing import TypeGuard + _T = TypeVar("_T", bound=Any) @@ -153,6 +159,12 @@ _TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] if TYPE_CHECKING: + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: + ... + + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: + ... + def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: ... @@ -183,7 +195,13 @@ if TYPE_CHECKING: def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ... + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: + ... + else: + + is_sql_compiler = operator.attrgetter("is_sql") + is_ddl_compiler = operator.attrgetter("is_ddl") is_named_from_clause = operator.attrgetter("named_with_column") is_column_element = operator.attrgetter("_is_column_element") is_text_clause = operator.attrgetter("_is_text_clause") @@ -194,6 +212,7 @@ else: is_select_statement = operator.attrgetter("_is_select_statement") is_table = operator.attrgetter("_is_table") is_subquery = operator.attrgetter("_is_subquery") + is_dml = operator.attrgetter("is_dml") def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: @@ -206,3 +225,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: return hasattr(s, "__clause_element__") + + +def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: + return c.is_dml and (c.is_insert or c.is_update) # type: ignore diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 6b25d8fcd5..f766a5ac50 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -937,7 +937,7 @@ class Executable(roles.StatementRole, Generative): _with_context_options: Tuple[ Tuple[Callable[[CompileState], None], Any], ... ] = () - _compile_options: Optional[CacheableOptions] + _compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]] _executable_traverse_internals = [ ("_with_options", InternalTraversal.dp_executable_options), @@ -982,7 +982,7 @@ class Executable(roles.StatementRole, Generative): ) -> Result: ... - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self): raise NotImplementedError() diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6ecfbf9866..522a0bd4a0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -61,6 +61,8 @@ from . import operators from . import schema from . import selectable from . import sqltypes +from ._typing import is_column_element +from ._typing import is_dml from .base import _from_objects from .base import Executable from .base import NO_ARG @@ -90,6 +92,7 @@ if typing.TYPE_CHECKING: from .elements import _truncated_label from .elements import BindParameter from .elements import ColumnClause + from .elements import ColumnElement from .elements import Label from .functions import Function from .selectable import Alias @@ -492,6 +495,9 @@ class Compiled: defaults. """ + is_sql = False + is_ddl = False + _cached_metadata: Optional[CursorResultMetaData] = None _result_columns: Optional[List[ResultColumnsEntry]] = None @@ -701,6 +707,8 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + is_sql = True + _result_columns: List[ResultColumnsEntry] compound_keywords = COMPOUND_KEYWORDS @@ -725,9 +733,14 @@ class SQLCompiler(Compiled): """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" - returning: Optional[Sequence[roles.ColumnsClauseRole]] - """list of columns that will be delivered to cursor.description or - dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE + implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None + """list of "implicit" returning columns for a toplevel INSERT or UPDATE + statement, used to receive newly generated values of columns. + + .. versionadded:: 2.0 ``implicit_returning`` replaces the previous + ``returning`` collection, which was not a generalized RETURNING + collection and instead was in fact specific to the "implicit returning" + feature. """ @@ -750,12 +763,6 @@ class SQLCompiler(Compiled): TypeEngine. CursorResult uses this for type processing and column targeting""" - returning = None - """holds the "returning" collection of columns if - the statement is CRUD and defines returning columns - either implicitly or explicitly - """ - returning_precedes_values: bool = False """set to True classwide to generate RETURNING clauses before the VALUES or WHERE clause (i.e. MSSQL) @@ -978,9 +985,6 @@ class SQLCompiler(Compiled): if TYPE_CHECKING: assert isinstance(statement, UpdateBase) - if statement._returning: - self.returning = statement._returning - if self.isinsert or self.isupdate: if TYPE_CHECKING: assert isinstance(statement, ValuesBase) @@ -1001,6 +1005,39 @@ class SQLCompiler(Compiled): if self._render_postcompile: self._process_parameters_for_postcompile(_populate_self=True) + @util.ro_memoized_property + def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]: + """The effective "returning" columns for INSERT, UPDATE or DELETE. + + This is either the so-called "implicit returning" columns which are + calculated by the compiler on the fly, or those present based on what's + present in ``self.statement._returning`` (expanded into individual + columns using the ``._all_selected_columns`` attribute) i.e. those set + explicitly using the :meth:`.UpdateBase.returning` method. + + .. versionadded:: 2.0 + + """ + if self.implicit_returning: + return self.implicit_returning + elif is_dml(self.statement): + return [ + c + for c in self.statement._all_selected_columns + if is_column_element(c) + ] + + else: + return None + + @property + def returning(self): + """backwards compatibility; returns the + effective_returning collection. + + """ + return self.effective_returning + @property def current_executable(self): """Return the current 'executable' that is being compiled. @@ -1569,7 +1606,7 @@ class SQLCompiler(Compiled): param_key_getter = self._within_exec_param_key_getter table = self.statement.table - returning = self.returning + returning = self.implicit_returning assert returning is not None ret = {col: idx for idx, col in enumerate(returning)} @@ -3373,7 +3410,9 @@ class SQLCompiler(Compiled): ResultColumnsEntry(keyname, name, objects, type_) ) - def _label_returning_column(self, stmt, column, column_clause_args=None): + def _label_returning_column( + self, stmt, column, populate_result_map, column_clause_args=None + ): """Render a column with necessary labels inside of a RETURNING clause. This method is provided for individual dialects in place of calling @@ -3386,7 +3425,7 @@ class SQLCompiler(Compiled): return self._label_select_column( None, column, - True, + populate_result_map, False, {} if column_clause_args is None else column_clause_args, ) @@ -4103,7 +4142,10 @@ class SQLCompiler(Compiled): def returning_clause( self, stmt: UpdateBase, - returning_cols: Sequence[roles.ColumnsClauseRole], + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, ) -> str: raise exc.CompileError( "RETURNING is not supported by this " @@ -4228,7 +4270,6 @@ class SQLCompiler(Compiled): return dialect_hints, table_text def visit_insert(self, insert_stmt, **kw): - compile_state = insert_stmt._compile_state_factory( insert_stmt, self, **kw ) @@ -4250,7 +4291,7 @@ class SQLCompiler(Compiled): ) crud_params_struct = crud._get_crud_params( - self, insert_stmt, compile_state, **kw + self, insert_stmt, compile_state, toplevel, **kw ) crud_params_single = crud_params_struct.single_params @@ -4303,9 +4344,11 @@ class SQLCompiler(Compiled): [expr for _, expr, _ in crud_params_single] ) - if self.returning or insert_stmt._returning: + if self.implicit_returning or insert_stmt._returning: returning_clause = self.returning_clause( - insert_stmt, self.returning or insert_stmt._returning + insert_stmt, + self.implicit_returning or insert_stmt._returning, + populate_result_map=toplevel, ) if self.returning_precedes_values: @@ -4449,7 +4492,7 @@ class SQLCompiler(Compiled): update_stmt, update_stmt.table, render_extra_froms, **kw ) crud_params_struct = crud._get_crud_params( - self, update_stmt, compile_state, **kw + self, update_stmt, compile_state, toplevel, **kw ) crud_params = crud_params_struct.single_params @@ -4473,10 +4516,12 @@ class SQLCompiler(Compiled): ) ) - if self.returning or update_stmt._returning: + if self.implicit_returning or update_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, ) if extra_froms: @@ -4502,10 +4547,12 @@ class SQLCompiler(Compiled): text += " " + limit_clause if ( - self.returning or update_stmt._returning + self.implicit_returning or update_stmt._returning ) and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning or update_stmt._returning + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, ) if self.ctes: @@ -4585,7 +4632,9 @@ class SQLCompiler(Compiled): if delete_stmt._returning: if self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning + delete_stmt, + delete_stmt._returning, + populate_result_map=toplevel, ) if extra_froms: @@ -4608,7 +4657,9 @@ class SQLCompiler(Compiled): if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( - delete_stmt, delete_stmt._returning + delete_stmt, + delete_stmt._returning, + populate_result_map=toplevel, ) if self.ctes: @@ -4685,7 +4736,14 @@ class StrSQLCompiler(SQLCompiler): def visit_sequence(self, seq, **kw): return "" % self.preparer.format_sequence(seq) - def returning_clause(self, stmt, returning_cols): + def returning_clause( + self, + stmt: UpdateBase, + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, + ) -> str: columns = [ self._label_select_column(None, c, True, False, {}) for c in base._select_iterables(returning_cols) @@ -4733,6 +4791,8 @@ class StrSQLCompiler(SQLCompiler): class DDLCompiler(Compiled): + is_ddl = True + if TYPE_CHECKING: def __init__( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 91a3f70c91..f6db2c4b25 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -87,6 +87,7 @@ def _get_crud_params( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, + toplevel: bool, **kw: Any, ) -> _CrudParams: """create a set of tuples representing column/string pairs for use @@ -99,10 +100,33 @@ def _get_crud_params( """ + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. + + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + compiler.postfetch = [] compiler.insert_prefetch = [] compiler.update_prefetch = [] - compiler.returning = [] + compiler.implicit_returning = [] # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for @@ -213,6 +237,7 @@ def _get_crud_params( _col_bind_name, check_columns, values, + toplevel, kw, ) else: @@ -226,6 +251,7 @@ def _get_crud_params( _col_bind_name, check_columns, values, + toplevel, kw, ) @@ -419,6 +445,7 @@ def _scan_insert_from_select_cols( _col_bind_name, check_columns, values, + toplevel, kw, ): @@ -427,7 +454,7 @@ def _scan_insert_from_select_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state) + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] @@ -472,6 +499,7 @@ def _scan_cols( _col_bind_name, check_columns, values, + toplevel, kw, ): ( @@ -479,7 +507,7 @@ def _scan_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state) + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) if compile_state._parameter_ordering: parameter_ordering = [ @@ -556,11 +584,11 @@ def _scan_cols( # column has a DDL-level default, and is either not a pk # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif ( c.primary_key and c is not stmt.table._autoincrement_column @@ -628,7 +656,7 @@ def _append_param_parameter( if compile_state.isupdate: if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) @@ -636,12 +664,12 @@ def _append_param_parameter( if c.primary_key: if implicit_returning: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif compiler.dialect.postfetch_lastrowid: compiler.postfetch_lastrowid = True elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: # postfetch specifically means, "we can SELECT the row we just @@ -674,7 +702,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.process(c.default, **kw), ) ) - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif c.default.is_clause_element: values.append( ( @@ -683,7 +711,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.process(c.default.arg.self_group(), **kw), ) ) - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: # client side default. OK we can't use RETURNING, need to # do a "prefetch", which in fact fetches the default value @@ -696,7 +724,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): ) ) elif c is stmt.table._autoincrement_column or c.server_default is not None: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -794,7 +822,7 @@ def _append_param_insert_hasdefault( ) ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: @@ -807,7 +835,7 @@ def _append_param_insert_hasdefault( ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch compiler.postfetch.append(c) @@ -870,7 +898,7 @@ def _append_param_update( ) ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) else: @@ -886,7 +914,7 @@ def _append_param_update( ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) elif ( @@ -894,7 +922,7 @@ def _append_param_update( and (stmt._return_defaults_columns or not stmt._return_defaults) and c in implicit_return_defaults ): - compiler.returning.append(c) + compiler.implicit_returning.append(c) @overload @@ -1195,10 +1223,11 @@ def _get_stmt_parameter_tuples_params( values.append((k, col_expr, v)) -def _get_returning_modifiers(compiler, stmt, compile_state): +def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): need_pks = ( - compile_state.isinsert + toplevel + and compile_state.isinsert and not stmt._inline and ( not compiler.for_executemany diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 8a3a1b38f0..f23ba2e6e2 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -463,11 +463,18 @@ class UpdateBase( ) return self - @util.non_memoized_property + def corresponding_column( + self, column: ColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: + return self.exported_columns.corresponding_column( + column, require_embedded=require_embedded + ) + + @util.ro_memoized_property def _all_selected_columns(self) -> _SelectIterable: return [c for c in _select_iterables(self._returning)] - @property + @util.ro_memoized_property def exported_columns( self, ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index aec29d1b2e..77813b98f7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -305,6 +305,7 @@ class ClauseElement( is_clause_element = True is_selectable = False + is_dml = False _is_column_element = False _is_table = False _is_textual = False diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6504449f1d..292225ce2e 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -127,6 +127,9 @@ if TYPE_CHECKING: _ColumnsClauseElement = Union["FromClause", ColumnElement[Any], "TextClause"] +_LabelConventionCallable = Callable[ + [Union["ColumnElement[Any]", "TextClause"]], Optional[str] +] class _JoinTargetProtocol(Protocol): @@ -183,7 +186,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): def selectable(self) -> ReturnsRows: return self - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: """A sequence of column expression objects that represents the "selected" columns of this :class:`_expression.ReturnsRows`. @@ -3277,7 +3280,7 @@ class SelectBase( """ raise NotImplementedError() - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: """A sequence of expressions that correspond to what is rendered in the columns clause, including :class:`_sql.TextClause` @@ -3586,7 +3589,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): ) -> None: self.element._generate_fromclause_column_proxies(subquery) - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.element._all_selected_columns @@ -4297,7 +4300,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect): for select in self.selects: select._refresh_for_new_column(column) - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.selects[0]._all_selected_columns @@ -4408,7 +4411,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def _column_naming_convention( cls, label_style: SelectLabelStyle - ) -> Callable[[Union[ColumnElement[Any], TextClause]], Optional[str]]: + ) -> _LabelConventionCallable: table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL dedupe = label_style is not LABEL_STYLE_NONE @@ -5984,7 +5987,7 @@ class Select( ) return cc.as_readonly() - @HasMemoized.memoized_attribute + @HasMemoized_ro_memoized_attribute def _all_selected_columns(self) -> _SelectIterable: meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) @@ -6537,14 +6540,7 @@ class TextualSelect(SelectBase): (c.key, c) for c in self.column_args ).as_readonly() - # def _generate_columns_plus_names( - # self, anon_for_dupe_key: bool - # ) -> List[Tuple[str, str, str, ColumnElement[Any], bool]]: - # return Select._generate_columns_plus_names( - # self, anon_for_dupe_key=anon_for_dupe_key - # ) - - @util.non_memoized_property + @util.ro_non_memoized_property def _all_selected_columns(self) -> _SelectIterable: return self.column_args diff --git a/setup.cfg b/setup.cfg index 02a1ec8aab..5ef2c6f22c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -162,5 +162,4 @@ mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008 docker_mssql = mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test -oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521 -oracle8 = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 +oracle = oracle+cx_oracle://scott:tiger@oracle18c diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index e827fa56c0..60cdb2577a 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -31,6 +31,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.suite import test_select @@ -54,7 +55,7 @@ class DialectTest(fixtures.TestBase): ): assert_raises_message( exc.InvalidRequestError, - "cx_Oracle version 5.2 and above are supported", + "cx_Oracle version 7 and above are supported", cx_oracle.OracleDialect_cx_oracle, dbapi=mock.Mock(), ) @@ -62,7 +63,7 @@ class DialectTest(fixtures.TestBase): with mock.patch( "sqlalchemy.dialects.oracle.cx_oracle.OracleDialect_cx_oracle." "_parse_cx_oracle_ver", - lambda self, vers: (5, 3, 1), + lambda self, vers: (7, 1, 0), ): cx_oracle.OracleDialect_cx_oracle(dbapi=mock.Mock()) @@ -301,19 +302,6 @@ class EncodingErrorsTest(fixtures.TestBase): else: assert_raises(UnicodeDecodeError, outconverter, utf8_w_errors) - @_oracle_char_combinations - def test_older_cx_oracle_warning(self, cx_Oracle, cx_oracle_type): - cx_Oracle.version = "6.3" - - with testing.expect_warnings( - r"cx_oracle version \(6, 3\) does not support encodingErrors" - ): - dialect = cx_oracle.dialect( - dbapi=cx_Oracle, encoding_errors="ignore" - ) - - eq_(dialect._cursor_var_unicode_kwargs, {}) - @_oracle_char_combinations def test_encoding_errors_cx_oracle( self, @@ -478,6 +466,23 @@ end; eq_(result.out_parameters, {"x_out": 10, "y_out": 75, "z_out": None}) assert isinstance(result.out_parameters["x_out"], int) + def test_no_out_params_w_returning(self, connection, metadata): + t = Table("t", metadata, Column("x", Integer), Column("y", Integer)) + metadata.create_all(connection) + stmt = ( + t.insert() + .values(x=5, y=10) + .returning(outparam("my_param", Integer), t.c.x) + ) + + with expect_raises_message( + exc.InvalidRequestError, + r"Using explicit outparam\(\) objects with " + r"UpdateBase.returning\(\) in the same Core DML statement " + "is not supported in the Oracle dialect.", + ): + connection.execute(stmt) + @classmethod def teardown_test_class(cls): with testing.db.begin() as conn: diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index e1c610701c..8f0a4b0229 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -483,15 +483,15 @@ class ImplicitReturningFlagTest(fixtures.TestBase): ): stmt.compile(conn) else: - eq_(stmt.compile(conn).returning, [t.c.id]) + eq_(stmt.compile(conn).implicit_returning, [t.c.id]) elif ( implicit_returning is None and testing.db.dialect.implicit_returning ): - eq_(stmt.compile(conn).returning, [t.c.id]) + eq_(stmt.compile(conn).implicit_returning, [t.c.id]) else: - eq_(stmt.compile(conn).returning, []) + eq_(stmt.compile(conn).implicit_returning, []) # table setting it to False disables it stmt2 = insert(t2).values(data="data") - eq_(stmt2.compile(conn).returning, []) + eq_(stmt2.compile(conn).implicit_returning, []) diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index 0b079055cd..c3e5b586a6 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -224,8 +224,10 @@ class DeferredReflectionTest(DeferredReflectBase): eq_(len(_DeferredMapperConfig._configs), 2) del Address gc_collect() + gc_collect() eq_(len(_DeferredMapperConfig._configs), 1) DeferredReflection.prepare(testing.db) + gc_collect() assert not _DeferredMapperConfig._configs diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 8147589368..8e568aef05 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1585,6 +1585,8 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): user_is = user._sa_instance_state del user gc_collect() + gc_collect() + gc_collect() assert user_is.obj() is None assert len(s.identity_map) == 0 diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index bc84d44475..ec4fb2d682 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -1267,6 +1267,8 @@ class AutoExpireTest(_LocalFixture): assert u1_state not in s._deleted del u1 gc_collect() + gc_collect() + gc_collect() assert u1_state.obj() is None s.rollback() diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 74a60bd215..5f02fde4c9 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -1423,7 +1423,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): stmt = table.insert().return_defaults().values(id=func.foobar()) compiled = stmt.compile(dialect=sqlite.dialect(), column_keys=["data"]) eq_(compiled.postfetch, []) - eq_(compiled.returning, []) + eq_(compiled.implicit_returning, []) self.assert_compile( stmt, @@ -1452,7 +1452,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=returning_dialect, column_keys=["data"] ) eq_(compiled.postfetch, []) - eq_(compiled.returning, [table.c.id]) + eq_(compiled.implicit_returning, [table.c.id]) self.assert_compile( stmt, @@ -1482,7 +1482,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=returning_dialect, column_keys=["data"] ) eq_(compiled.postfetch, []) - eq_(compiled.returning, [table.c.id]) + eq_(compiled.implicit_returning, [table.c.id]) self.assert_compile( stmt,