From: Mike Bayer Date: Tue, 30 Apr 2019 03:26:36 +0000 (-0400) Subject: Implement new ClauseElement role and coercion system X-Git-Tag: rel_1_4_0b1~869 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f07e050c9ce4afdeb9c0c136dbcc547f7e5ac7b8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement new ClauseElement role and coercion system A major refactoring of all the functions handle all detection of Core argument types as well as perform coercions into a new class hierarchy based on "roles", each of which identify a syntactical location within a SQL statement. In contrast to the ClauseElement hierarchy that identifies "what" each object is syntactically, the SQLRole hierarchy identifies the "where does it go" of each object syntactically. From this we define a consistent type checking and coercion system that establishes well defined behviors. This is a breakout of the patch that is reorganizing select() constructs to no longer be in the FromClause hierarchy. Also includes a rename of as_scalar() into scalar_subquery(); deprecates automatic coercion to scalar_subquery(). Partially-fixes: #4617 Change-Id: I26f1e78898693c6b99ef7ea2f4e7dfd0e8e1a1bd --- diff --git a/doc/build/changelog/unreleased_14/4617_coercion.rst b/doc/build/changelog/unreleased_14/4617_coercion.rst new file mode 100644 index 0000000000..93be6d57dc --- /dev/null +++ b/doc/build/changelog/unreleased_14/4617_coercion.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: sql, change + :tickets: 4617 + + The "clause coercion" system, which is SQLAlchemy Core's system of receiving + arguments and resolving them into :class:`.ClauseElement` structures in order + to build up SQL expression objects, has been rewritten from a series of + ad-hoc functions to a fully consistent class-based system. This change + is internal and should have no impact on end users other than more specific + error messages when the wrong kind of argument is passed to an expression + object, however the change is part of a larger set of changes involving + the role and behavior of :func:`.select` objects. + diff --git a/doc/build/changelog/unreleased_14/4617_scalar.rst b/doc/build/changelog/unreleased_14/4617_scalar.rst new file mode 100644 index 0000000000..3f22414f7f --- /dev/null +++ b/doc/build/changelog/unreleased_14/4617_scalar.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: change, sql + :tickets: 4617 + + The :meth:`.SelectBase.as_scalar` and :meth:`.Query.as_scalar` methods have + been renamed to :meth:`.SelectBase.scalar_subquery` and :meth:`.Query.scalar_subquery`, + respectively. The old names continue to exist within 1.4 series with a deprecation + warning. In addition, the implicit coercion of :class:`.SelectBase`, :class:`.Alias`, + and other SELECT oriented objects into scalar subqueries when evaluated in a column + context is also deprecated, and emits a warning that the :meth:`.SelectBase.scalar_subquery` + method should be called explicitly. This warning will in a later major release + become an error, however the message will always be clear when :meth:`.SelectBase.scalar_subquery` + needs to be invoked. The latter part of the change is for clarity and to reduce the + implicit decisionmaking by the query coercion system. + + This change is part of the larger change to convert :func:`.select` objects to no + longer be directly part of the "from clause" class hierarchy, which also includes + an overhaul of the clause coercion system. + diff --git a/doc/build/core/tutorial.rst b/doc/build/core/tutorial.rst index cdacedadfe..4409996bdf 100644 --- a/doc/build/core/tutorial.rst +++ b/doc/build/core/tutorial.rst @@ -1590,24 +1590,24 @@ is often a :term:`correlated subquery`, which relies upon the enclosing SELECT statement in order to acquire at least one of its FROM clauses. The :func:`.select` construct can be modified to act as a -column expression by calling either the :meth:`~.SelectBase.as_scalar` +column expression by calling either the :meth:`~.SelectBase.scalar_subquery` or :meth:`~.SelectBase.label` method: .. sourcecode:: pycon+sql - >>> stmt = select([func.count(addresses.c.id)]).\ + >>> subq = select([func.count(addresses.c.id)]).\ ... where(users.c.id == addresses.c.user_id).\ - ... as_scalar() + ... scalar_subquery() The above construct is now a :class:`~.expression.ScalarSelect` object, -and is no longer part of the :class:`~.expression.FromClause` hierarchy; -it instead is within the :class:`~.expression.ColumnElement` family of -expression constructs. We can place this construct the same as any +which is an adapter around the original :class:`.~expression.Select` +object; it participates within the :class:`~.expression.ColumnElement` +family of expression constructs. We can place this construct the same as any other column within another :func:`.select`: .. sourcecode:: pycon+sql - >>> conn.execute(select([users.c.name, stmt])).fetchall() + >>> conn.execute(select([users.c.name, subq])).fetchall() {opensql}SELECT users.name, (SELECT count(addresses.id) AS count_1 FROM addresses WHERE users.id = addresses.user_id) AS anon_1 @@ -1620,10 +1620,10 @@ it using :meth:`.SelectBase.label` instead: .. sourcecode:: pycon+sql - >>> stmt = select([func.count(addresses.c.id)]).\ + >>> subq = select([func.count(addresses.c.id)]).\ ... where(users.c.id == addresses.c.user_id).\ ... label("address_count") - >>> conn.execute(select([users.c.name, stmt])).fetchall() + >>> conn.execute(select([users.c.name, subq])).fetchall() {opensql}SELECT users.name, (SELECT count(addresses.id) AS count_1 FROM addresses WHERE users.id = addresses.user_id) AS address_count @@ -1633,7 +1633,7 @@ it using :meth:`.SelectBase.label` instead: .. seealso:: - :meth:`.Select.as_scalar` + :meth:`.Select.scalar_subquery` :meth:`.Select.label` @@ -1642,7 +1642,7 @@ it using :meth:`.SelectBase.label` instead: Correlated Subqueries --------------------- -Notice in the examples on :ref:`scalar_selects`, the FROM clause of each embedded +In the examples on :ref:`scalar_selects`, the FROM clause of each embedded select did not contain the ``users`` table in its FROM clause. This is because SQLAlchemy automatically :term:`correlates` embedded FROM objects to that of an enclosing query, if present, and if the inner SELECT statement would @@ -1653,7 +1653,8 @@ still have at least one FROM clause of its own. For example: >>> stmt = select([addresses.c.user_id]).\ ... where(addresses.c.user_id == users.c.id).\ ... where(addresses.c.email_address == 'jack@yahoo.com') - >>> enclosing_stmt = select([users.c.name]).where(users.c.id == stmt) + >>> enclosing_stmt = select([users.c.name]).\ + ... where(users.c.id == stmt.scalar_subquery()) >>> conn.execute(enclosing_stmt).fetchall() {opensql}SELECT users.name FROM users @@ -1679,7 +1680,7 @@ may be correlated: >>> enclosing_stmt = select( ... [users.c.name, addresses.c.email_address]).\ ... select_from(users.join(addresses)).\ - ... where(users.c.id == stmt) + ... where(users.c.id == stmt.scalar_subquery()) >>> conn.execute(enclosing_stmt).fetchall() {opensql}SELECT users.name, addresses.email_address FROM users JOIN addresses ON users.id = addresses.user_id @@ -1698,7 +1699,7 @@ as the argument: ... where(users.c.name == 'wendy').\ ... correlate(None) >>> enclosing_stmt = select([users.c.name]).\ - ... where(users.c.id == stmt) + ... where(users.c.id == stmt.scalar_subquery()) >>> conn.execute(enclosing_stmt).fetchall() {opensql}SELECT users.name FROM users @@ -1721,7 +1722,7 @@ by telling it to correlate all FROM clauses except for ``users``: >>> enclosing_stmt = select( ... [users.c.name, addresses.c.email_address]).\ ... select_from(users.join(addresses)).\ - ... where(users.c.id == stmt) + ... where(users.c.id == stmt.scalar_subquery()) >>> conn.execute(enclosing_stmt).fetchall() {opensql}SELECT users.name, addresses.email_address FROM users JOIN addresses ON users.id = addresses.user_id diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index d2c84f446a..00a110aa22 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -672,6 +672,7 @@ from ... import util from ...engine import default from ...engine import reflection from ...sql import compiler +from ...sql import elements from ...sql import expression from ...sql import quoted_name from ...sql import util as sql_util @@ -1671,9 +1672,7 @@ class MSSQLCompiler(compiler.SQLCompiler): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: - converted = expression._corresponding_column_or_error( - t, column - ) + converted = elements._corresponding_column_or_error(t, column) if add_to_result_map is not None: add_to_result_map( column.name, diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 44f90c47cd..9cae3c689f 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -788,8 +788,10 @@ from ... import types as sqltypes from ... import util from ...engine import default from ...engine import reflection +from ...sql import coercions from ...sql import compiler from ...sql import elements +from ...sql import roles from ...types import BINARY from ...types import BLOB from ...types import BOOLEAN @@ -1218,7 +1220,7 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_on_duplicate_key_update(self, on_duplicate, **kw): if on_duplicate._parameter_ordering: parameter_ordering = [ - elements._column_as_key(key) + coercions.expect(roles.DMLColumnRole, key) for key in on_duplicate._parameter_ordering ] ordered_keys = set(parameter_ordering) @@ -1238,7 +1240,7 @@ class MySQLCompiler(compiler.SQLCompiler): val = on_duplicate.update.get(column.key) if val is None: continue - elif elements._is_literal(val): + elif coercions._is_literal(val): val = elements.BindParameter(None, val, type_=column.type) value_text = self.process(val.self_group(), use_schema=False) elif isinstance(val, elements.BindParameter) and val.type._isnull: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index ceb6246442..f18bec932b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -932,9 +932,11 @@ from ... import sql from ... import util from ...engine import default from ...engine import reflection +from ...sql import coercions from ...sql import compiler from ...sql import elements from ...sql import expression +from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util from ...types import BIGINT @@ -1774,7 +1776,7 @@ class PGCompiler(compiler.SQLCompiler): col_key = c.key if col_key in set_parameters: value = set_parameters.pop(col_key) - if elements._is_literal(value): + if coercions._is_literal(value): value = elements.BindParameter(None, value, type_=c.type) else: @@ -1806,7 +1808,8 @@ class PGCompiler(compiler.SQLCompiler): else self.process(k, use_schema=False) ) value_text = self.process( - elements._literal_as_binds(v), use_schema=False + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, ) action_set_ops.append("%s = %s" % (key_text, value_text)) diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index 4260282397..f9cbc945ac 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -6,9 +6,12 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from .array import ARRAY +from ... import util +from ...sql import coercions from ...sql import elements from ...sql import expression from ...sql import functions +from ...sql import roles from ...sql.schema import ColumnCollectionConstraint @@ -50,16 +53,18 @@ class aggregate_order_by(expression.ColumnElement): __visit_name__ = "aggregate_order_by" def __init__(self, target, *order_by): - self.target = elements._literal_as_binds(target) + self.target = coercions.expect(roles.ExpressionElementRole, target) _lob = len(order_by) if _lob == 0: raise TypeError("at least one ORDER BY element is required") elif _lob == 1: - self.order_by = elements._literal_as_binds(order_by[0]) + self.order_by = coercions.expect( + roles.ExpressionElementRole, order_by[0] + ) else: self.order_by = elements.ClauseList( - *order_by, _literal_as_text=elements._literal_as_binds + *order_by, _literal_as_text_role=roles.ExpressionElementRole ) def self_group(self, against=None): @@ -166,7 +171,10 @@ class ExcludeConstraint(ColumnCollectionConstraint): expressions, operators = zip(*elements) for (expr, column, strname, add_element), operator in zip( - self._extract_col_expression_collection(expressions), operators + coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ), + operators, ): if add_element is not None: columns.append(add_element) @@ -177,8 +185,6 @@ class ExcludeConstraint(ColumnCollectionConstraint): # backwards compat self.operators[name] = operator - expr = expression._literal_as_column(expr) - render_exprs.append((expr, name, operator)) self._render_exprs = render_exprs @@ -193,9 +199,21 @@ class ExcludeConstraint(ColumnCollectionConstraint): self.using = kw.get("using", "gist") where = kw.get("where") if where is not None: - self.where = expression._literal_as_text( - where, allow_coercion_to_text=True + self.where = coercions.expect(roles.StatementOptionRole, where) + + def _set_parent(self, table): + super(ExcludeConstraint, self)._set_parent(table) + + self._render_exprs = [ + ( + expr if isinstance(expr, elements.ClauseElement) else colexpr, + name, + operator, + ) + for (expr, name, operator), colexpr in util.zip_longest( + self._render_exprs, self.columns ) + ] def copy(self, **kw): elements = [(col, self.operators[col]) for col in self.columns.keys()] diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index fcb1d41554..d0b16a7457 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -626,7 +626,7 @@ class ResultMetaData(object): if raiseerr: raise exc.NoSuchColumnError( "Could not locate column in row for column '%s'" - % expression._string_or_unprintable(key) + % util.string_or_unprintable(key) ) else: return None diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 09a9d73b73..4bb55ca493 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -269,7 +269,8 @@ class BakedQuery(object): User.id == Address.user_id).correlate(Address) main_bq = self.bakery( - lambda s: s.query(Address.id, sub_bq.to_query(q).as_scalar()) + lambda s: s.query( + Address.id, sub_bq.to_query(q).scalar_subquery()) ) :param query_or_session: a :class:`.Query` object or a class diff --git a/lib/sqlalchemy/ext/declarative/clsregistry.py b/lib/sqlalchemy/ext/declarative/clsregistry.py index 25752f2e82..1d05ddc996 100644 --- a/lib/sqlalchemy/ext/declarative/clsregistry.py +++ b/lib/sqlalchemy/ext/declarative/clsregistry.py @@ -217,7 +217,7 @@ class _GetColumns(object): mp = class_mapper(self.cls, configure=False) if mp: if key not in mp.all_orm_descriptors: - raise exc.InvalidRequestError( + raise AttributeError( "Class %r does not have a mapped column named %r" % (self.cls, key) ) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 1b8c8c7f33..f37928cc1e 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -109,10 +109,6 @@ class QueryableAttribute( instance_state(instance), instance_dict(instance), passive ) - def __selectable__(self): - # TODO: conditionally attach this method based on clause_element ? - return self - @util.memoized_property def info(self): """Return the 'info' dictionary for the underlying SQL element. diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index b64b35e72e..f809d5891b 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -15,7 +15,6 @@ from . import exc from .. import exc as sa_exc from .. import inspection from .. import util -from ..sql import expression PASSIVE_NO_RESULT = util.symbol( @@ -356,13 +355,6 @@ def _is_mapped_class(entity): ) -def _attr_as_key(attr): - if hasattr(attr, "key"): - return attr.key - else: - return expression._column_as_key(attr) - - def _orm_columns(entity): insp = inspection.inspect(entity, False) if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"): diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index b9297e15c9..6bd009fb89 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -110,8 +110,9 @@ from sqlalchemy.util.compat import inspect_getfullargspec from . import base from .. import exc as sa_exc from .. import util +from ..sql import coercions from ..sql import expression - +from ..sql import roles __all__ = [ "collection", @@ -243,7 +244,7 @@ def column_mapped_collection(mapping_spec): """ cols = [ - expression._only_column_elements(q, "mapping_spec") + coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") for q in util.to_list(mapping_spec) ] keyfunc = _PlainColumnGetter(cols) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d2b08a9087..5098a55ce3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -35,6 +35,7 @@ from .base import MANYTOONE from .base import NOT_EXTENSION from .base import ONETOMANY from .. import inspect +from .. import inspection from .. import util from ..sql import operators @@ -270,6 +271,7 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots): ) +@inspection._self_inspects class PropComparator(operators.ColumnOperators): r"""Defines SQL operators for :class:`.MapperProperty` objects. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6eadffb160..ccf05a7833 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -45,8 +45,10 @@ from .. import log from .. import schema from .. import sql from .. import util +from ..sql import coercions from ..sql import expression from ..sql import operators +from ..sql import roles from ..sql import util as sql_util from ..sql import visitors @@ -644,8 +646,14 @@ class Mapper(InspectionAttr): self.batch = batch self.eager_defaults = eager_defaults self.column_prefix = column_prefix - self.polymorphic_on = expression._clause_element_as_expr( - polymorphic_on + self.polymorphic_on = ( + coercions.expect( + roles.ColumnArgumentOrKeyRole, + polymorphic_on, + argname="polymorphic_on", + ) + if polymorphic_on is not None + else None ) self._dependency_processors = [] self.validators = util.immutabledict() @@ -1548,14 +1556,6 @@ class Mapper(InspectionAttr): "can be passed for polymorphic_on" ) prop = self.polymorphic_on - elif not expression._is_column(self.polymorphic_on): - # polymorphic_on is not a Column and not a ColumnProperty; - # not supported right now. - raise sa_exc.ArgumentError( - "Only direct column-mapped " - "property or SQL expression " - "can be passed for polymorphic_on" - ) else: # polymorphic_on is a Column or SQL expression and # doesn't appear to be mapped. this means it can be 1. @@ -1851,11 +1851,7 @@ class Mapper(InspectionAttr): # generate a properties.ColumnProperty columns = util.to_list(prop) column = columns[0] - if not expression._is_column(column): - raise sa_exc.ArgumentError( - "%s=%r is not an instance of MapperProperty or Column" - % (key, prop) - ) + assert isinstance(column, expression.ColumnElement) prop = self._props.get(key, None) @@ -2260,6 +2256,9 @@ class Mapper(InspectionAttr): for table, columns in self._cols_by_table.items() ) + def __clause_element__(self): + return self.selectable + @property def selectable(self): """The :func:`.select` construct this :class:`.Mapper` selects from diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index e686b61c36..5106bff94e 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -28,7 +28,9 @@ from .base import state_str from .. import exc as sa_exc from .. import sql from .. import util +from ..sql import coercions from ..sql import expression +from ..sql import roles from ..sql.base import _from_objects @@ -1914,7 +1916,7 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): values = self._resolved_values_keys_as_propnames for key, value in values: self.value_evaluators[key] = evaluator_compiler.process( - expression._literal_as_binds(value) + coercions.expect(roles.ExpressionElementRole, value) ) def _do_post_synchronize(self): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 2b323429fc..e2c10e50aa 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -19,7 +19,8 @@ from .interfaces import StrategizedProperty from .util import _orm_full_deannotate from .. import log from .. import util -from ..sql import expression +from ..sql import coercions +from ..sql import roles __all__ = ["ColumnProperty"] @@ -131,9 +132,14 @@ class ColumnProperty(StrategizedProperty): """ super(ColumnProperty, self).__init__() - self._orig_columns = [expression._labeled(c) for c in columns] + self._orig_columns = [ + coercions.expect(roles.LabeledColumnExprRole, c) for c in columns + ] self.columns = [ - expression._labeled(_orm_full_deannotate(c)) for c in columns + coercions.expect( + roles.LabeledColumnExprRole, _orm_full_deannotate(c) + ) + for c in columns ] self.group = kwargs.pop("group", None) self.deferred = kwargs.pop("deferred", False) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6d777ceae6..8ef1663a12 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -47,11 +47,12 @@ from .. import inspection from .. import log from .. import sql from .. import util +from ..sql import coercions from ..sql import expression +from ..sql import roles from ..sql import util as sql_util from ..sql import visitors from ..sql.base import ColumnCollection -from ..sql.expression import _interpret_as_from from ..sql.selectable import ForUpdateArg @@ -252,7 +253,7 @@ class Query(object): "expected when the base alias is being set." ) fa.append(info.selectable) - elif not info.is_selectable: + elif not info.is_clause_element or not info._is_from_clause: raise sa_exc.ArgumentError( "argument is not a mapped class, mapper, " "aliased(), or FromClause instance." @@ -309,9 +310,7 @@ class Query(object): def _adapt_col_list(self, cols): return [ - self._adapt_clause( - expression._literal_as_label_reference(o), True, True - ) + self._adapt_clause(coercions.expect(roles.ByOfRole, o), True, True) for o in cols ] @@ -637,6 +636,12 @@ class Query(object): return self.enable_eagerloads(False).statement.label(name) + @util.deprecated( + "1.4", + "The :meth:`.Query.as_scalar` method is deprecated and will be " + "removed in a future release. Please refer to " + ":meth:`.Query.scalar_subquery`.", + ) def as_scalar(self): """Return the full SELECT statement represented by this :class:`.Query`, converted to a scalar subquery. @@ -644,19 +649,20 @@ class Query(object): Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.as_scalar`. """ + return self.scalar_subquery() - return self.enable_eagerloads(False).statement.as_scalar() - - @property - def selectable(self): - """Return the :class:`.Select` object emitted by this :class:`.Query`. - - Used for :func:`.inspect` compatibility, this is equivalent to:: + def scalar_subquery(self): + """Return the full SELECT statement represented by this + :class:`.Query`, converted to a scalar subquery. - query.enable_eagerloads(False).with_labels().statement + Analogous to + :meth:`sqlalchemy.sql.expression.SelectBase.scalar_subquery`. + .. versionchanged:: 1.4 the :meth:`.Query.scalar_subquery` method + replaces the :meth:`.Query.as_scalar` method. """ - return self.__clause_element__() + + return self.enable_eagerloads(False).statement.scalar_subquery() def __clause_element__(self): return self.enable_eagerloads(False).with_labels().statement @@ -1094,7 +1100,9 @@ class Query(object): self._correlate = self._correlate.union([None]) else: self._correlate = self._correlate.union( - sql_util.surface_selectables(_interpret_as_from(s)) + sql_util.surface_selectables( + coercions.expect(roles.FromClauseRole, s) + ) ) @_generative() @@ -1757,7 +1765,7 @@ class Query(object): """ for criterion in list(criterion): - criterion = expression._expression_literal_as_text(criterion) + criterion = coercions.expect(roles.WhereHavingRole, criterion) criterion = self._adapt_clause(criterion, True, True) @@ -1870,7 +1878,7 @@ class Query(object): """ - criterion = expression._expression_literal_as_text(criterion) + criterion = coercions.expect(roles.WhereHavingRole, criterion) if criterion is not None and not isinstance( criterion, sql.ClauseElement @@ -3184,7 +3192,7 @@ class Query(object): ORM tutorial """ - statement = expression._expression_literal_as_text(statement) + statement = coercions.expect(roles.SelectStatementRole, statement) if not isinstance( statement, (expression.TextClause, expression.SelectBase) @@ -4651,7 +4659,7 @@ class QueryContext(object): if query._statement is not None: if ( isinstance(query._statement, expression.SelectBase) - and not query._statement._textual + and not query._statement._is_textual and not query._statement.use_labels ): self.statement = query._statement.apply_labels() diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index b40fad332b..8b03955e22 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -36,8 +36,10 @@ from .. import schema from .. import sql from .. import util from ..inspection import inspect +from ..sql import coercions from ..sql import expression from ..sql import operators +from ..sql import roles from ..sql import visitors from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate @@ -63,7 +65,7 @@ def remote(expr): """ return _annotate_columns( - expression._clause_element_as_expr(expr), {"remote": True} + coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True} ) @@ -83,7 +85,7 @@ def foreign(expr): """ return _annotate_columns( - expression._clause_element_as_expr(expr), {"foreign": True} + coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True} ) @@ -1897,7 +1899,9 @@ class RelationshipProperty(StrategizedProperty): self, attr, _orm_deannotate( - expression._only_column_elements(val, attr) + coercions.expect( + roles.ColumnArgumentRole, val, argname=attr + ) ), ) @@ -1905,17 +1909,23 @@ class RelationshipProperty(StrategizedProperty): # remote_side are all columns, not strings. if self.order_by is not False and self.order_by is not None: self.order_by = [ - expression._only_column_elements(x, "order_by") + coercions.expect( + roles.ColumnArgumentRole, x, argname="order_by" + ) for x in util.to_list(self.order_by) ] self._user_defined_foreign_keys = util.column_set( - expression._only_column_elements(x, "foreign_keys") + coercions.expect( + roles.ColumnArgumentRole, x, argname="foreign_keys" + ) for x in util.to_column_set(self._user_defined_foreign_keys) ) self.remote_side = util.column_set( - expression._only_column_elements(x, "remote_side") + coercions.expect( + roles.ColumnArgumentRole, x, argname="remote_side" + ) for x in util.to_column_set(self.remote_side) ) @@ -2653,7 +2663,9 @@ class JoinCondition(object): else: def repl(element): - if element in remote_side: + # use set() to avoid generating ``__eq__()`` expressions + # against each element + if element in set(remote_side): return element._annotate({"remote": True}) self.primaryjoin = visitors.replacement_traverse( diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 517fc2b36b..cc3361e904 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -32,7 +32,8 @@ from .. import exc as sa_exc from .. import sql from .. import util from ..inspection import inspect -from ..sql import expression +from ..sql import coercions +from ..sql import roles from ..sql import util as sql_util @@ -1257,9 +1258,7 @@ class Session(_SessionClassMethods): in order to execute the statement. """ - clause = expression._literal_as_text( - clause, allow_coercion_to_text=True - ) + clause = coercions.expect(roles.CoerceTextStatementRole, clause) if bind is None: bind = self.get_bind(mapper, clause=clause, **kw) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 912b6b5503..ebcb101adb 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -24,7 +24,8 @@ from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import inspect from .. import util -from ..sql import expression as sql_expr +from ..sql import coercions +from ..sql import roles from ..sql.base import _generative from ..sql.base import Generative @@ -1543,7 +1544,9 @@ def with_expression(loadopt, key, expression): """ - expression = sql_expr._labeled(_orm_full_deannotate(expression)) + expression = coercions.expect( + roles.LabeledColumnExprRole, _orm_full_deannotate(expression) + ) return loadopt.set_column_strategy( (key,), {"query_expression": True}, opts={"expression": expression} diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 36da781a41..e574181068 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -628,6 +628,9 @@ class AliasedInsp(InspectionAttr): is_aliased_class = True "always returns True" + def __clause_element__(self): + return self.selectable + @property def class_(self): """Return the mapped class ultimately represented by this diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index fb5639ef37..00cafd8ff8 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -94,6 +94,22 @@ def __go(lcls): from .elements import ClauseList # noqa from .selectable import AnnotatedFromClause # noqa + from . import base + from . import coercions + from . import elements + from . import selectable + from . import schema + from . import sqltypes + from . import type_api + + base.coercions = elements.coercions = coercions + base.elements = elements + base.type_api = type_api + coercions.elements = elements + coercions.schema = schema + coercions.selectable = selectable + coercions.sqltypes = sqltypes + _prepare_annotations(ColumnElement, AnnotatedColumnElement) _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index c5e5fd8a1b..9df0c932f9 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -17,6 +17,9 @@ from .visitors import ClauseVisitor from .. import exc from .. import util +coercions = None # type: types.ModuleType +elements = None # type: types.ModuleType +type_api = None # type: types.ModuleType PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") NO_ARG = util.symbol("NO_ARG") @@ -589,8 +592,7 @@ class ColumnCollection(util.OrderedProperties): __hash__ = None - @util.dependencies("sqlalchemy.sql.elements") - def __eq__(self, elements, other): + def __eq__(self, other): l = [] for c in getattr(other, "_all_columns", other): for local in self._all_columns: @@ -636,8 +638,7 @@ class ColumnSet(util.ordered_column_set): def __add__(self, other): return list(self) + list(other) - @util.dependencies("sqlalchemy.sql.elements") - def __eq__(self, elements, other): + def __eq__(self, other): l = [] for c in other: for local in self: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py new file mode 100644 index 0000000000..7c7222f9f3 --- /dev/null +++ b/lib/sqlalchemy/sql/coercions.py @@ -0,0 +1,580 @@ +# sql/coercions.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import numbers +import re + +from . import operators +from . import roles +from . import visitors +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util import collections_abc + +elements = None # type: types.ModuleType +schema = None # type: types.ModuleType +selectable = None # type: types.ModuleType +sqltypes = None # type: types.ModuleType + + +def _is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + """ + return not isinstance( + element, (Visitable, schema.SchemaEventTarget) + ) and not hasattr(element, "__clause_element__") + + +def _document_text_coercion(paramname, meth_rst, param_rst): + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + +def expect(role, element, **kw): + # major case is that we are given a ClauseElement already, skip more + # elaborate logic up front if possible + impl = _impl_lookup[role] + + if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)): + resolved = impl._resolve_for_clause_element(element, **kw) + else: + resolved = element + + if issubclass(resolved.__class__, impl._role_class): + if impl._post_coercion: + resolved = impl._post_coercion(resolved, **kw) + return resolved + else: + return impl._implicit_coercions(element, resolved, **kw) + + +def expect_as_key(role, element, **kw): + kw["as_key"] = True + return expect(role, element, **kw) + + +def expect_col_expression_collection(role, expressions): + for expr in expressions: + strname = None + column = None + + resolved = expect(role, expr) + if isinstance(resolved, util.string_types): + strname = resolved = expr + else: + cols = [] + visitors.traverse(resolved, {}, {"column": cols.append}) + if cols: + column = cols[0] + add_element = column if column is not None else strname + yield resolved, column, strname, add_element + + +class RoleImpl(object): + __slots__ = ("_role_class", "name", "_use_inspection") + + def _literal_coercion(self, element, **kw): + raise NotImplementedError() + + _post_coercion = None + + def __init__(self, role_class): + self._role_class = role_class + self.name = role_class._role_name + self._use_inspection = issubclass(role_class, roles.UsesInspection) + + def _resolve_for_clause_element(self, element, argname=None, **kw): + literal_coercion = self._literal_coercion + original_element = element + is_clause_element = False + + while hasattr(element, "__clause_element__") and not isinstance( + element, (elements.ClauseElement, schema.SchemaItem) + ): + element = element.__clause_element__() + is_clause_element = True + + if not is_clause_element: + if self._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + try: + return insp.__clause_element__() + except AttributeError: + self._raise_for_expected(original_element, argname) + + return self._literal_coercion(element, argname=argname, **kw) + else: + return element + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + self._raise_for_expected(element, argname) + + def _raise_for_expected(self, element, argname=None): + if argname: + raise exc.ArgumentError( + "%s expected for argument %r; got %r." + % (self.name, argname, element) + ) + else: + raise exc.ArgumentError( + "%s expected, got %r." % (self.name, element) + ) + + +class _StringOnly(object): + def _resolve_for_clause_element(self, element, argname=None, **kw): + return self._literal_coercion(element, **kw) + + +class _ReturnsStringKey(object): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return original_element + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, **kw): + return element + + +class _ColumnCoercions(object): + def _warn_for_scalar_subquery_coercion(self): + util.warn_deprecated( + "coercing SELECT object to scalar subquery in a " + "column-expression context is deprecated in version 1.4; " + "please use the .scalar_subquery() method to produce a scalar " + "subquery. This automatic coercion will be removed in a " + "future release." + ) + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_select_statement: + self._warn_for_scalar_subquery_coercion() + return resolved.scalar_subquery() + elif ( + resolved._is_from_clause + and isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + self._warn_for_scalar_subquery_coercion() + return resolved.original.scalar_subquery() + else: + self._raise_for_expected(original_element, argname) + + +def _no_text_coercion( + element, argname=None, exc_cls=exc.ArgumentError, extra=None +): + raise exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ) + + +class _NoTextCoercion(object): + def _literal_coercion(self, element, argname=None): + if isinstance(element, util.string_types) and issubclass( + elements.TextClause, self._role_class + ): + _no_text_coercion(element, argname) + else: + self._raise_for_expected(element, argname) + + +class _CoerceLiterals(object): + _coerce_consts = False + _coerce_star = False + _coerce_numerics = False + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + def _literal_coercion(self, element, argname=None): + if isinstance(element, util.string_types): + if self._coerce_star and element == "*": + return elements.ColumnClause("*", is_literal=True) + else: + return self._text_coercion(element, argname) + + if self._coerce_consts: + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + + if self._coerce_numerics and isinstance(element, (numbers.Number)): + return elements.ColumnClause(str(element), is_literal=True) + + self._raise_for_expected(element, argname) + + +class ExpressionElementImpl( + _ColumnCoercions, RoleImpl, roles.ExpressionElementRole +): + def _literal_coercion(self, element, name=None, type_=None, argname=None): + if element is None: + return elements.Null() + else: + try: + return elements.BindParameter( + name, element, type_, unique=True + ) + except exc.ArgumentError: + self._raise_for_expected(element) + + +class BinaryElementImpl( + ExpressionElementImpl, RoleImpl, roles.BinaryElementRole +): + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None + ): + try: + return expr._bind_param(operator, element, type_=bindparam_type) + except exc.ArgumentError: + self._raise_for_expected(element) + + def _post_coercion(self, resolved, expr, **kw): + if ( + isinstance(resolved, elements.BindParameter) + and resolved.type._isnull + ): + resolved = resolved._clone() + resolved.type = expr.type + return resolved + + +class InElementImpl(RoleImpl, roles.InElementRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + return resolved.original + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, expr, operator, **kw): + if isinstance(element, collections_abc.Iterable) and not isinstance( + element, util.string_types + ): + args = [] + for o in element: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + self._raise_for_expected(element, **kw) + elif o is None: + o = elements.Null() + else: + o = expr._bind_param(operator, o) + args.append(o) + + return elements.ClauseList(*args) + + else: + self._raise_for_expected(element, **kw) + + def _post_coercion(self, element, expr, operator, **kw): + if element._is_select_statement: + return element.scalar_subquery() + elif isinstance(element, elements.ClauseList): + if len(element.clauses) == 0: + op, negate_op = ( + (operators.empty_in_op, operators.empty_notin_op) + if operator is operators.in_op + else (operators.empty_notin_op, operators.empty_in_op) + ) + return element.self_group(against=op)._annotate( + dict(in_ops=(op, negate_op)) + ) + else: + return element.self_group(against=operator) + + elif isinstance(element, elements.BindParameter) and element.expanding: + + if isinstance(expr, elements.Tuple): + element = element._with_expanding_in_types( + [elem.type for elem in expr] + ) + return element + else: + return element + + +class WhereHavingImpl( + _CoerceLiterals, _ColumnCoercions, RoleImpl, roles.WhereHavingRole +): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + +class StatementOptionImpl( + _CoerceLiterals, RoleImpl, roles.StatementOptionRole +): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class ColumnArgumentImpl(_NoTextCoercion, RoleImpl, roles.ColumnArgumentRole): + pass + + +class ColumnArgumentOrKeyImpl( + _ReturnsStringKey, RoleImpl, roles.ColumnArgumentOrKeyRole +): + pass + + +class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements._textual_label_reference(element) + + +class OrderByImpl(ByOfImpl, RoleImpl, roles.OrderByRole): + def _post_coercion(self, resolved): + if ( + isinstance(resolved, self._role_class) + and resolved._order_by_label_element is not None + ): + return elements._label_reference(resolved) + else: + return resolved + + +class DMLColumnImpl(_ReturnsStringKey, RoleImpl, roles.DMLColumnRole): + def _post_coercion(self, element, as_key=False): + if as_key: + return element.key + else: + return element + + +class ConstExprImpl(RoleImpl, roles.ConstExprRole): + def _literal_coercion(self, element, argname=None): + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + else: + self._raise_for_expected(element, argname) + + +class TruncatedLabelImpl(_StringOnly, RoleImpl, roles.TruncatedLabelRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return resolved + else: + self._raise_for_expected(original_element, argname) + + def _literal_coercion(self, element, argname=None): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(element, elements._truncated_label): + return element + else: + return elements._truncated_label(element) + + +class DDLExpressionImpl(_CoerceLiterals, RoleImpl, roles.DDLExpressionRole): + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class DDLConstraintColumnImpl( + _ReturnsStringKey, RoleImpl, roles.DDLConstraintColumnRole +): + pass + + +class LimitOffsetImpl(RoleImpl, roles.LimitOffsetRole): + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if resolved is None: + return None + else: + self._raise_for_expected(element, argname) + + def _literal_coercion(self, element, name, type_, **kw): + if element is None: + return None + else: + value = util.asint(element) + return selectable._OffsetLimitParam( + name, value, type_=type_, unique=True + ) + + +class LabeledColumnExprImpl( + ExpressionElementImpl, roles.LabeledColumnExprRole +): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.ExpressionElementRole): + return resolved.label(None) + else: + new = super(LabeledColumnExprImpl, self)._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + if isinstance(new, roles.ExpressionElementRole): + return new.label(None) + else: + self._raise_for_expected(original_element, argname) + + +class ColumnsClauseImpl(_CoerceLiterals, RoleImpl, roles.ColumnsClauseRole): + + _coerce_consts = True + _coerce_numerics = True + _coerce_star = True + + _guess_straight_column = re.compile(r"^\w\S*$", re.I) + + def _text_coercion(self, element, argname=None): + element = str(element) + + guess_is_literal = not self._guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r %(argname)sshould be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "literal_column": "literal_column" + if guess_is_literal + else "column", + } + ) + + +class ReturnsRowsImpl(RoleImpl, roles.ReturnsRowsRole): + pass + + +class StatementImpl(_NoTextCoercion, RoleImpl, roles.StatementRole): + pass + + +class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl, roles.StatementRole): + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class SelectStatementImpl( + _NoTextCoercion, RoleImpl, roles.SelectStatementRole +): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved.columns() + else: + self._raise_for_expected(original_element, argname) + + +class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole): + pass + + +class FromClauseImpl(_NoTextCoercion, RoleImpl, roles.FromClauseRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved + else: + self._raise_for_expected(original_element, argname) + + +class DMLSelectImpl(_NoTextCoercion, RoleImpl, roles.DMLSelectRole): + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.original._is_select_statement + ): + return resolved.original + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname) + + +class CompoundElementImpl( + _NoTextCoercion, RoleImpl, roles.CompoundElementRole +): + def _implicit_coercions(self, original_element, resolved, argname=None): + if resolved._is_from_clause: + return resolved + else: + self._raise_for_expected(original_element, argname) + + +_impl_lookup = {} + + +for name in dir(roles): + cls = getattr(roles, name) + if name.endswith("Role"): + name = name.replace("Role", "Impl") + if name in globals(): + impl = globals()[name](cls) + _impl_lookup[cls] = impl diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c7fe3dc50e..8080d2cc66 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -27,14 +27,15 @@ import contextlib import itertools import re +from . import coercions from . import crud from . import elements from . import functions from . import operators +from . import roles from . import schema from . import selectable from . import sqltypes -from . import visitors from .. import exc from .. import util @@ -400,7 +401,9 @@ class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): return type_._compiler_dispatch(self, **kw) -class _CompileLabel(visitors.Visitable): +# this was a Visitable, but to allow accurate detection of +# column elements this is actually a column element +class _CompileLabel(elements.ColumnElement): """lightweight label object which acts as an expression.Label.""" @@ -766,10 +769,10 @@ class SQLCompiler(Compiled): else: col = with_cols[element.element] except KeyError: - elements._no_text_coercion( + coercions._no_text_coercion( element.element, - exc.CompileError, - "Can't resolve label reference for ORDER BY / GROUP BY.", + extra="Can't resolve label reference for ORDER BY / GROUP BY.", + exc_cls=exc.CompileError, ) else: kwargs["render_label_as_label"] = col @@ -1635,7 +1638,6 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_name[cte_name] = cte - # look for embedded DML ctes and propagate autocommit if ( "autocommit" in cte.element._execution_options and "autocommit" not in self.execution_options @@ -1656,10 +1658,10 @@ class SQLCompiler(Compiled): self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) if cte.recursive: - if isinstance(cte.original, selectable.Select): - col_source = cte.original - elif isinstance(cte.original, selectable.CompoundSelect): - col_source = cte.original.selects[0] + if isinstance(cte.element, selectable.Select): + col_source = cte.element + elif isinstance(cte.element, selectable.CompoundSelect): + col_source = cte.element.selects[0] else: assert False recur_cols = [ @@ -1810,7 +1812,7 @@ class SQLCompiler(Compiled): ): result_expr = _CompileLabel( col_expr, - elements._as_truncated(column.name), + coercions.expect(roles.TruncatedLabelRole, column.name), alt_names=(column.key,), ) elif ( @@ -1830,7 +1832,7 @@ class SQLCompiler(Compiled): # assert isinstance(column, elements.ColumnClause) result_expr = _CompileLabel( col_expr, - elements._as_truncated(column.name), + coercions.expect(roles.TruncatedLabelRole, column.name), alt_names=(column.key,), ) else: @@ -1880,7 +1882,7 @@ class SQLCompiler(Compiled): newelem = cloned[element] = element._clone() if ( - newelem.is_selectable + newelem._is_from_clause and newelem._is_join and isinstance(newelem.right, selectable.FromGrouping) ): @@ -1933,7 +1935,7 @@ class SQLCompiler(Compiled): # marker in the stack. kw["transform_clue"] = "select_container" newelem._copy_internals(clone=visit, **kw) - elif newelem.is_selectable and newelem._is_select: + elif newelem._is_returns_rows and newelem._is_select_statement: barrier_select = ( kw.get("transform_clue", None) == "select_container" ) @@ -2349,6 +2351,7 @@ class SQLCompiler(Compiled): + join_type + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + " ON " + # TODO: likely need asfrom=True here? + join.onclause._compiler_dispatch(self, **kwargs) ) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 552f61b4a0..881ea9fcda 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -9,14 +9,16 @@ within INSERT and UPDATE statements. """ +import functools import operator +from . import coercions from . import dml from . import elements +from . import roles from .. import exc from .. import util - REQUIRED = util.symbol( "REQUIRED", """ @@ -174,7 +176,7 @@ def _get_crud_params(compiler, stmt, **kw): if check: raise exc.CompileError( "Unconsumed column names: %s" - % (", ".join("%s" % c for c in check)) + % (", ".join("%s" % (c,) for c in check)) ) if stmt._has_multi_parameters: @@ -207,8 +209,12 @@ def _key_getters_for_crud_column(compiler, stmt): # statement. _et = set(stmt._extra_froms) + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + def _column_as_key(key): - str_key = elements._column_as_key(key) + str_key = c_key_role(key) if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: @@ -227,7 +233,9 @@ def _key_getters_for_crud_column(compiler, stmt): return col.key else: - _column_as_key = elements._column_as_key + _column_as_key = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) _getattr_col_key = _col_bind_name = operator.attrgetter("key") return _column_as_key, _getattr_col_key, _col_bind_name @@ -386,7 +394,7 @@ def _append_param_parameter( kw, ): value = parameters.pop(col_key) - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -633,9 +641,8 @@ def _get_multitable_params( values, kw, ): - normalized_params = dict( - (elements._clause_element_as_expr(c), param) + (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) affected_tables = set() @@ -645,7 +652,7 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] - if elements._is_literal(value): + if coercions._is_literal(value): value = _create_bind_param( compiler, c, @@ -697,7 +704,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if col in row or col.key in row: key = col if col in row else col.key - if elements._is_literal(row[key]): + if coercions._is_literal(row[key]): new_param = _create_bind_param( compiler, col, @@ -730,7 +737,7 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param - if elements._is_literal(v): + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw ) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index d87a6a1b04..ff36a68e4a 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -10,6 +10,7 @@ to invoke them for a create/drop call. """ +from . import roles from .base import _bind_or_error from .base import _generative from .base import Executable @@ -29,7 +30,7 @@ class _DDLCompiles(ClauseElement): return dialect.ddl_compiler(dialect, self, **kw) -class DDLElement(Executable, _DDLCompiles): +class DDLElement(roles.DDLRole, Executable, _DDLCompiles): """Base class for DDL expression constructs. This class is the base for the general purpose :class:`.DDL` class, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 9a12b84cd9..918f7524e5 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -8,32 +8,21 @@ """Default implementation of SQL comparison operations. """ + +from . import coercions from . import operators +from . import roles from . import type_api -from .elements import _clause_element_as_expr -from .elements import _const_expr -from .elements import _is_literal -from .elements import _literal_as_text from .elements import and_ from .elements import BinaryExpression -from .elements import BindParameter -from .elements import ClauseElement from .elements import ClauseList from .elements import collate from .elements import CollectionAggregate -from .elements import ColumnElement from .elements import False_ from .elements import Null from .elements import or_ -from .elements import TextClause from .elements import True_ -from .elements import Tuple from .elements import UnaryExpression -from .elements import Visitable -from .selectable import Alias -from .selectable import ScalarSelect -from .selectable import Selectable -from .selectable import SelectBase from .. import exc from .. import util @@ -62,7 +51,7 @@ def _boolean_compare( ): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -71,7 +60,7 @@ def _boolean_compare( elif op in (operators.is_distinct_from, operators.isnot_distinct_from): return BinaryExpression( expr, - _literal_as_text(obj), + coercions.expect(roles.ConstExprRole, obj), op, type_=result_type, negate=negate, @@ -82,7 +71,7 @@ def _boolean_compare( if op in (operators.eq, operators.is_): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.is_, negate=operators.isnot, type_=result_type, @@ -90,7 +79,7 @@ def _boolean_compare( elif op in (operators.ne, operators.isnot): return BinaryExpression( expr, - _const_expr(obj), + coercions.expect(roles.ConstExprRole, obj), operators.isnot, negate=operators.is_, type_=result_type, @@ -102,7 +91,9 @@ def _boolean_compare( "operators can be used with None/True/False" ) else: - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) if reverse: return BinaryExpression( @@ -127,7 +118,9 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): - obj = _check_literal(expr, op, obj) + obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) if reverse: left, right = obj, expr @@ -156,77 +149,22 @@ def _scalar(expr, op, fn, **kw): def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): - seq_or_selectable = _clause_element_as_expr(seq_or_selectable) - - if isinstance(seq_or_selectable, ScalarSelect): - return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op) - elif isinstance(seq_or_selectable, SelectBase): - - # TODO: if we ever want to support (x, y, z) IN (select x, - # y, z from table), we would need a multi-column version of - # as_scalar() to produce a multi- column selectable that - # does not export itself as a FROM clause - - return _boolean_compare( - expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op, **kw - ) - elif isinstance(seq_or_selectable, ClauseElement): - if ( - isinstance(seq_or_selectable, BindParameter) - and seq_or_selectable.expanding - ): - - if isinstance(expr, Tuple): - seq_or_selectable = seq_or_selectable._with_expanding_in_types( - [elem.type for elem in expr] - ) - - return _boolean_compare( - expr, op, seq_or_selectable, negate=negate_op - ) - else: - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' - % seq_or_selectable - ) - - # Handle non selectable arguments as sequences - args = [] - for o in seq_or_selectable: - if not _is_literal(o): - if not isinstance(o, operators.ColumnOperators): - raise exc.InvalidRequestError( - "in_() accepts" - " either a list of expressions, " - 'a selectable, or an "expanding" bound parameter: %r' % o - ) - elif o is None: - o = Null() - else: - o = expr._bind_param(op, o) - args.append(o) - - if len(args) == 0: - op, negate_op = ( - (operators.empty_in_op, operators.empty_notin_op) - if op is operators.in_op - else (operators.empty_notin_op, operators.empty_in_op) - ) + seq_or_selectable = coercions.expect( + roles.InElementRole, seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] return _boolean_compare( - expr, op, ClauseList(*args).self_group(against=op), negate=negate_op + expr, op, seq_or_selectable, negate=negate_op, **kw ) def _getitem_impl(expr, op, other, **kw): if isinstance(expr.type, type_api.INDEXABLE): - other = _check_literal(expr, op, other) + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) return _binary_operate(expr, op, other, **kw) else: _unsupported_impl(expr, op, other, **kw) @@ -257,7 +195,12 @@ def _match_impl(expr, op, other, **kw): return _boolean_compare( expr, operators.match_op, - _check_literal(expr, operators.match_op, other), + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), result_type=type_api.MATCHTYPE, negate=operators.notmatch_op if op is operators.match_op @@ -278,8 +221,18 @@ def _between_impl(expr, op, cleft, cright, **kw): return BinaryExpression( expr, ClauseList( - _check_literal(expr, operators.and_, cleft), - _check_literal(expr, operators.and_, cright), + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), operator=operators.and_, group=False, group_contents=False, @@ -349,22 +302,3 @@ operator_lookup = { "rshift": (_unsupported_impl,), "contains": (_unsupported_impl,), } - - -def _check_literal(expr, operator, other, bindparam_type=None): - if isinstance(other, (ColumnElement, TextClause)): - if isinstance(other, BindParameter) and other.type._isnull: - other = other._clone() - other.type = expr.type - return other - elif hasattr(other, "__clause_element__"): - other = other.__clause_element__() - elif isinstance(other, type_api.TypeEngine.Comparator): - other = other.expr - - if isinstance(other, (SelectBase, Alias)): - return other.as_scalar() - elif not isinstance(other, Visitable): - return expr._bind_param(operator, other, type_=bindparam_type) - else: - return other diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3c40e79143..c7d83fc12b 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,18 +9,16 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ +from . import coercions +from . import roles from .base import _from_objects from .base import _generative from .base import DialectKWArgs from .base import Executable from .elements import _clone -from .elements import _column_as_key -from .elements import _literal_as_text from .elements import and_ from .elements import ClauseElement from .elements import Null -from .selectable import _interpret_as_from -from .selectable import _interpret_as_select from .selectable import HasCTE from .selectable import HasPrefixes from .. import exc @@ -28,7 +26,12 @@ from .. import util class UpdateBase( - HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement + roles.DMLRole, + HasCTE, + DialectKWArgs, + HasPrefixes, + Executable, + ClauseElement, ): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. @@ -210,7 +213,7 @@ class ValuesBase(UpdateBase): _post_values_clause = None def __init__(self, table, values, prefixes): - self.table = _interpret_as_from(table) + self.table = coercions.expect(roles.FromClauseRole, table) self.parameters, self._has_multi_parameters = self._process_colparams( values ) @@ -604,13 +607,16 @@ class Insert(ValuesBase): ) self.parameters, self._has_multi_parameters = self._process_colparams( - {_column_as_key(n): Null() for n in names} + { + coercions.expect(roles.DMLColumnRole, n, as_key=True): Null() + for n in names + } ) self.select_names = names self.inline = True self.include_insert_from_select_defaults = include_defaults - self.select = _interpret_as_select(select) + self.select = coercions.expect(roles.DMLSelectRole, select) def _copy_internals(self, clone=_clone, **kw): # TODO: coverage @@ -678,7 +684,7 @@ class Update(ValuesBase): users.update().values(name='ed').where( users.c.name==select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) :param values: @@ -744,7 +750,7 @@ class Update(ValuesBase): users.update().values( name=select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) .. seealso:: @@ -759,7 +765,9 @@ class Update(ValuesBase): self._bind = bind self._returning = returning if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) else: self._whereclause = None self.inline = inline @@ -785,10 +793,13 @@ class Update(ValuesBase): """ if self._whereclause is not None: self._whereclause = and_( - self._whereclause, _literal_as_text(whereclause) + self._whereclause, + coercions.expect(roles.WhereHavingRole, whereclause), ) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) @property def _extra_froms(self): @@ -846,7 +857,7 @@ class Delete(UpdateBase): users.delete().where( users.c.name==select([addresses.c.email_address]).\ where(addresses.c.user_id==users.c.id).\ - as_scalar() + scalar_subquery() ) .. versionchanged:: 1.2.0 @@ -858,14 +869,16 @@ class Delete(UpdateBase): """ self._bind = bind - self.table = _interpret_as_from(table) + self.table = coercions.expect(roles.FromClauseRole, table) self._returning = returning if prefixes: self._setup_prefixes(prefixes) if whereclause is not None: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) else: self._whereclause = None @@ -883,10 +896,13 @@ class Delete(UpdateBase): if self._whereclause is not None: self._whereclause = and_( - self._whereclause, _literal_as_text(whereclause) + self._whereclause, + coercions.expect(roles.WhereHavingRole, whereclause), ) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ) @property def _extra_froms(self): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e634e5a367..a333303ec2 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -13,12 +13,13 @@ from __future__ import unicode_literals import itertools -import numbers import operator import re from . import clause_compare +from . import coercions from . import operators +from . import roles from . import type_api from .annotation import Annotated from .base import _generative @@ -26,6 +27,7 @@ from .base import Executable from .base import Immutable from .base import NO_ARG from .base import PARSE_AUTOCOMMIT +from .coercions import _document_text_coercion from .visitors import cloned_traverse from .visitors import traverse from .visitors import Visitable @@ -38,20 +40,6 @@ def _clone(element, **kw): return element._clone() -def _document_text_coercion(paramname, meth_rst, param_rst): - return util.add_parameter_text( - paramname, - ( - ".. warning:: " - "The %s argument to %s can be passed as a Python string argument, " - "which will be treated " - "as **trusted SQL text** and rendered as given. **DO NOT PASS " - "UNTRUSTED INPUT TO THIS PARAMETER**." - ) - % (param_rst, meth_rst), - ) - - def collate(expression, collation): """Return the clause ``expression COLLATE collation``. @@ -71,7 +59,7 @@ def collate(expression, collation): """ - expr = _literal_as_binds(expression) + expr = coercions.expect(roles.ExpressionElementRole, expression) return BinaryExpression( expr, CollationClause(collation), operators.collate, type_=expr.type ) @@ -127,7 +115,7 @@ def between(expr, lower_bound, upper_bound, symmetric=False): :meth:`.ColumnElement.between` """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) return expr.between(lower_bound, upper_bound, symmetric=symmetric) @@ -172,11 +160,11 @@ def not_(clause): same result. """ - return operators.inv(_literal_as_binds(clause)) + return operators.inv(coercions.expect(roles.ExpressionElementRole, clause)) @inspection._self_inspects -class ClauseElement(Visitable): +class ClauseElement(roles.SQLRole, Visitable): """Base class for elements of a programmatically constructed SQL expression. @@ -188,13 +176,20 @@ class ClauseElement(Visitable): supports_execution = False _from_objects = [] bind = None + description = None _is_clone_of = None - is_selectable = False + is_clause_element = True + is_selectable = False - description = None - _order_by_label_element = None + _is_textual = False + _is_from_clause = False + _is_returns_rows = False + _is_text_clause = False _is_from_container = False + _is_select_statement = False + + _order_by_label_element = None def _clone(self): """Create a shallow copy of this ClauseElement. @@ -238,7 +233,7 @@ class ClauseElement(Visitable): """ - raise NotImplementedError(self.__class__) + raise NotImplementedError() @property def _constructor(self): @@ -394,6 +389,7 @@ class ClauseElement(Visitable): return [] def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement """Apply a 'grouping' to this :class:`.ClauseElement`. This method is overridden by subclasses to return a @@ -553,7 +549,20 @@ class ClauseElement(Visitable): ) -class ColumnElement(operators.ColumnOperators, ClauseElement): +class ColumnElement( + roles.ColumnArgumentOrKeyRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.BinaryElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + roles.LimitOffsetRole, + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + operators.ColumnOperators, + ClauseElement, +): """Represent a column-oriented SQL expression suitable for usage in the "columns" clause, WHERE clause etc. of a statement. @@ -586,17 +595,13 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): :class:`.TypeEngine` objects) are applied to the value. * any special object value, typically ORM-level constructs, which - feature a method called ``__clause_element__()``. The Core + feature an accessor called ``__clause_element__()``. The Core expression system looks for this method when an object of otherwise unknown type is passed to a function that is looking to coerce the - argument into a :class:`.ColumnElement` expression. The - ``__clause_element__()`` method, if present, should return a - :class:`.ColumnElement` instance. The primary use of - ``__clause_element__()`` within SQLAlchemy is that of class-bound - attributes on ORM-mapped classes; a ``User`` class which contains a - mapped attribute named ``.name`` will have a method - ``User.name.__clause_element__()`` which when invoked returns the - :class:`.Column` called ``name`` associated with the mapped table. + argument into a :class:`.ColumnElement` and sometimes a + :class:`.SelectBase` expression. It is used within the ORM to + convert from ORM-specific objects like mapped classes and + mapped attributes into Core expression objects. * The Python ``None`` value is typically interpreted as ``NULL``, which in SQLAlchemy Core produces an instance of :func:`.null`. @@ -702,6 +707,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): _alt_names = () def self_group(self, against=None): + # type: (Module, Module, Optional[Any]) -> ClauseEleent if ( against in (operators.and_, operators.or_, operators._asbool) and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity @@ -826,7 +832,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): else: key = name co = ColumnClause( - _as_truncated(name) if name_is_truncatable else name, + coercions.expect(roles.TruncatedLabelRole, name) + if name_is_truncatable + else name, type_=getattr(self, "type", None), _selectable=selectable, ) @@ -878,7 +886,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement): ) -class BindParameter(ColumnElement): +class BindParameter(roles.InElementRole, ColumnElement): r"""Represent a "bound expression". :class:`.BindParameter` is invoked explicitly using the @@ -1235,7 +1243,8 @@ class BindParameter(ColumnElement): "bindparams collection argument required for _cache_key " "implementation. Bound parameter cache keys are not safe " "to use without accommodating for the value or callable " - "within the parameter itself.") + "within the parameter itself." + ) else: bindparams.append(self) return (BindParameter, self.type._cache_key, self._orig_key) @@ -1282,7 +1291,20 @@ class TypeClause(ClauseElement): return (TypeClause, self.type._cache_key) -class TextClause(Executable, ClauseElement): +class TextClause( + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.OrderByRole, + roles.FromClauseRole, + roles.SelectStatementRole, + roles.CoerceTextStatementRole, + roles.BinaryElementRole, + roles.InElementRole, + Executable, + ClauseElement, +): """Represent a literal SQL text fragment. E.g.:: @@ -1304,6 +1326,10 @@ class TextClause(Executable, ClauseElement): __visit_name__ = "textclause" + _is_text_clause = True + + _is_textual = True + _bind_params_regex = re.compile(r"(? Union[Grouping, TextClause] if against is operators.in_op: return Grouping(self) else: @@ -1715,7 +1737,7 @@ class TextClause(Executable, ClauseElement): ) -class Null(ColumnElement): +class Null(roles.ConstExprRole, ColumnElement): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -1739,7 +1761,7 @@ class Null(ColumnElement): return (Null,) -class False_(ColumnElement): +class False_(roles.ConstExprRole, ColumnElement): """Represent the ``false`` keyword, or equivalent, in a SQL statement. :class:`.False_` is accessed as a constant via the @@ -1798,7 +1820,7 @@ class False_(ColumnElement): return (False_,) -class True_(ColumnElement): +class True_(roles.ConstExprRole, ColumnElement): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -1864,7 +1886,12 @@ class True_(ColumnElement): return (True_,) -class ClauseList(ClauseElement): +class ClauseList( + roles.InElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + ClauseElement, +): """Describe a list of clauses, separated by an operator. By default, is comma-separated, such as a column listing. @@ -1877,16 +1904,22 @@ class ClauseList(ClauseElement): self.operator = kwargs.pop("operator", operators.comma_op) self.group = kwargs.pop("group", True) self.group_contents = kwargs.pop("group_contents", True) - text_converter = kwargs.pop( - "_literal_as_text", _expression_literal_as_text + + self._text_converter_role = text_converter_role = kwargs.pop( + "_literal_as_text_role", roles.WhereHavingRole ) if self.group_contents: self.clauses = [ - text_converter(clause).self_group(against=self.operator) + coercions.expect(text_converter_role, clause).self_group( + against=self.operator + ) for clause in clauses ] else: - self.clauses = [text_converter(clause) for clause in clauses] + self.clauses = [ + coercions.expect(text_converter_role, clause) + for clause in clauses + ] self._is_implicitly_boolean = operators.is_boolean(self.operator) def __iter__(self): @@ -1902,10 +1935,14 @@ class ClauseList(ClauseElement): def append(self, clause): if self.group_contents: self.clauses.append( - _literal_as_text(clause).self_group(against=self.operator) + coercions.expect(self._text_converter_role, clause).self_group( + against=self.operator + ) ) else: - self.clauses.append(_literal_as_text(clause)) + self.clauses.append( + coercions.expect(self._text_converter_role, clause) + ) def _copy_internals(self, clone=_clone, **kw): self.clauses = [clone(clause, **kw) for clause in self.clauses] @@ -1923,6 +1960,7 @@ class ClauseList(ClauseElement): return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -1947,7 +1985,7 @@ class BooleanClauseList(ClauseList, ColumnElement): convert_clauses = [] clauses = [ - _expression_literal_as_text(clause) + coercions.expect(roles.WhereHavingRole, clause) for clause in util.coerce_generator_arg(clauses) ] for clause in clauses: @@ -2055,6 +2093,7 @@ class BooleanClauseList(ClauseList, ColumnElement): return (self,) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if not self.clauses: return self else: @@ -2092,7 +2131,9 @@ class Tuple(ClauseList, ColumnElement): """ - clauses = [_literal_as_binds(c) for c in clauses] + clauses = [ + coercions.expect(roles.ExpressionElementRole, c) for c in clauses + ] self._type_tuple = [arg.type for arg in clauses] self.type = kw.pop( "type_", @@ -2283,12 +2324,20 @@ class Case(ColumnElement): if value is not None: whenlist = [ - (_literal_as_binds(c).self_group(), _literal_as_binds(r)) + ( + coercions.expect( + roles.ExpressionElementRole, c + ).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) for (c, r) in whens ] else: whenlist = [ - (_no_literals(c).self_group(), _literal_as_binds(r)) + ( + coercions.expect(roles.ColumnArgumentRole, c).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) for (c, r) in whens ] @@ -2300,12 +2349,12 @@ class Case(ColumnElement): if value is None: self.value = None else: - self.value = _literal_as_binds(value) + self.value = coercions.expect(roles.ExpressionElementRole, value) self.type = type_ self.whens = whenlist if else_ is not None: - self.else_ = _literal_as_binds(else_) + self.else_ = coercions.expect(roles.ExpressionElementRole, else_) else: self.else_ = None @@ -2455,7 +2504,9 @@ class Cast(ColumnElement): """ self.type = type_api.to_instance(type_) - self.clause = _literal_as_binds(expression, type_=self.type) + self.clause = coercions.expect( + roles.ExpressionElementRole, expression, type_=self.type + ) self.typeclause = TypeClause(self.type) def _copy_internals(self, clone=_clone, **kw): @@ -2557,7 +2608,9 @@ class TypeCoerce(ColumnElement): """ self.type = type_api.to_instance(type_) - self.clause = _literal_as_binds(expression, type_=self.type) + self.clause = coercions.expect( + roles.ExpressionElementRole, expression, type_=self.type + ) def _copy_internals(self, clone=_clone, **kw): self.clause = clone(self.clause, **kw) @@ -2598,7 +2651,7 @@ class Extract(ColumnElement): """ self.type = type_api.INTEGERTYPE self.field = field - self.expr = _literal_as_binds(expr, None) + self.expr = coercions.expect(roles.ExpressionElementRole, expr) def _copy_internals(self, clone=_clone, **kw): self.expr = clone(self.expr, **kw) @@ -2733,7 +2786,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.nullsfirst_op, wraps_column_expression=False, ) @@ -2776,7 +2829,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.nullslast_op, wraps_column_expression=False, ) @@ -2817,7 +2870,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.desc_op, wraps_column_expression=False, ) @@ -2857,7 +2910,7 @@ class UnaryExpression(ColumnElement): """ return UnaryExpression( - _literal_as_label_reference(column), + coercions.expect(roles.ByOfRole, column), modifier=operators.asc_op, wraps_column_expression=False, ) @@ -2898,7 +2951,7 @@ class UnaryExpression(ColumnElement): :data:`.func` """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) return UnaryExpression( expr, operator=operators.distinct_op, @@ -2953,6 +3006,7 @@ class UnaryExpression(ColumnElement): return ClauseElement._negate(self) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2990,10 +3044,8 @@ class CollectionAggregate(UnaryExpression): """ - expr = _literal_as_binds(expr) + expr = coercions.expect(roles.ExpressionElementRole, expr) - if expr.is_selectable and hasattr(expr, "as_scalar"): - expr = expr.as_scalar() expr = expr.self_group() return CollectionAggregate( expr, @@ -3023,9 +3075,7 @@ class CollectionAggregate(UnaryExpression): """ - expr = _literal_as_binds(expr) - if expr.is_selectable and hasattr(expr, "as_scalar"): - expr = expr.as_scalar() + expr = coercions.expect(roles.ExpressionElementRole, expr) expr = expr.self_group() return CollectionAggregate( expr, @@ -3064,6 +3114,7 @@ class AsBoolean(UnaryExpression): self._is_implicitly_boolean = element._is_implicitly_boolean def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self def _cache_key(self, **kw): @@ -3155,6 +3206,8 @@ class BinaryExpression(ColumnElement): ) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement + if operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3191,6 +3244,7 @@ class Slice(ColumnElement): self.type = type_api.NULLTYPE def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement assert against is operator.getitem return self @@ -3215,6 +3269,7 @@ class Grouping(ColumnElement): self.type = getattr(element, "type", type_api.NULLTYPE) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self @util.memoized_property @@ -3363,13 +3418,12 @@ class Over(ColumnElement): self.element = element if order_by is not None: self.order_by = ClauseList( - *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) if partition_by is not None: self.partition_by = ClauseList( *util.to_list(partition_by), - _literal_as_text=_literal_as_label_reference + _literal_as_text_role=roles.ByOfRole ) if range_: @@ -3534,8 +3588,7 @@ class WithinGroup(ColumnElement): self.element = element if order_by is not None: self.order_by = ClauseList( - *util.to_list(order_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) def over(self, partition_by=None, order_by=None, range_=None, rows=None): @@ -3658,7 +3711,7 @@ class FunctionFilter(ColumnElement): """ for criterion in list(criterion): - criterion = _expression_literal_as_text(criterion) + criterion = coercions.expect(roles.WhereHavingRole, criterion) if self.criterion is not None: self.criterion = self.criterion & criterion @@ -3727,7 +3780,7 @@ class FunctionFilter(ColumnElement): ) -class Label(ColumnElement): +class Label(roles.LabeledColumnExprRole, ColumnElement): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -3801,6 +3854,7 @@ class Label(ColumnElement): return self._element.self_group(against=operators.as_) def self_group(self, against=None): + # type: (Optional[Any]) -> ClauseElement return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): @@ -3849,7 +3903,7 @@ class Label(ColumnElement): return e -class ColumnClause(Immutable, ColumnElement): +class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): """Represents a column expression from any textual string. The :class:`.ColumnClause`, a lightweight analogue to the @@ -3985,14 +4039,14 @@ class ColumnClause(Immutable, ColumnElement): if ( self.is_literal or self.table is None - or self.table._textual + or self.table._is_textual or not hasattr(other, "proxy_set") or ( isinstance(other, ColumnClause) and ( other.is_literal or other.table is None - or other.table._textual + or other.table._is_textual ) ) ): @@ -4083,7 +4137,7 @@ class ColumnClause(Immutable, ColumnElement): counter += 1 label = _label - return _as_truncated(label) + return coercions.expect(roles.TruncatedLabelRole, label) else: return name @@ -4110,7 +4164,7 @@ class ColumnClause(Immutable, ColumnElement): # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = self._constructor( - _as_truncated(name or self.name) + coercions.expect(roles.TruncatedLabelRole, name or self.name) if name_is_truncatable else (name or self.name), type_=self.type, @@ -4250,6 +4304,108 @@ class quoted_name(util.MemoizedSlots, util.text_type): return "'%s'" % backslashed +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ) + + +def _cloned_difference(a, b): + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + ) + + +def _find_columns(clause): + """locate Column objects within the given expression.""" + + cols = util.column_set() + traverse(clause, {}, {"column": cols.add}) + return cols + + +def _type_from_args(args): + for a in args: + if not a.type._isnull: + return a.type + else: + return type_api.NULLTYPE + + +def _corresponding_column_or_error(fromclause, column, require_embedded=False): + c = fromclause.corresponding_column( + column, require_embedded=require_embedded + ) + if c is None: + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " + "failed to locate a corresponding column from table '%s'" + % (column, getattr(column, "table", None), fromclause.description) + ) + return c + + +class AnnotatedColumnElement(Annotated): + def __init__(self, element, values): + Annotated.__init__(self, element, values) + ColumnElement.comparator._reset(self) + for attr in ("name", "key", "table"): + if self.__dict__.get(attr, False) is None: + self.__dict__.pop(attr) + + def _with_annotations(self, values): + clone = super(AnnotatedColumnElement, self)._with_annotations(values) + ColumnElement.comparator._reset(clone) + return clone + + @util.memoized_property + def name(self): + """pull 'name' from parent, if not present""" + return self._Annotated__element.name + + @util.memoized_property + def table(self): + """pull 'table' from parent, if not present""" + return self._Annotated__element.table + + @util.memoized_property + def key(self): + """pull 'key' from parent, if not present""" + return self._Annotated__element.key + + @util.memoized_property + def info(self): + return self._Annotated__element.info + + @util.memoized_property + def anon_label(self): + return self._Annotated__element.anon_label + + class _truncated_label(quoted_name): """A unicode subclass used to identify symbolic " "names that may require truncation.""" @@ -4378,349 +4534,3 @@ class _anonymous_label(_truncated_label): else: # else skip the constructor call return self % map_ - - -def _as_truncated(value): - """coerce the given value to :class:`._truncated_label`. - - Existing :class:`._truncated_label` and - :class:`._anonymous_label` objects are passed - unchanged. - """ - - if isinstance(value, _truncated_label): - return value - else: - return _truncated_label(value) - - -def _string_or_unprintable(element): - if isinstance(element, util.string_types): - return element - else: - try: - return str(element) - except Exception: - return "unprintable element %r" % element - - -def _expand_cloned(elements): - """expand the given set of ClauseElements to be the set of all 'cloned' - predecessors. - - """ - return itertools.chain(*[x._cloned_set for x in elements]) - - -def _select_iterables(elements): - """expand tables into individual columns in the - given list of column expressions. - - """ - return itertools.chain(*[c._select_iterable for c in elements]) - - -def _cloned_intersection(a, b): - """return the intersection of sets a and b, counting - any overlap between 'cloned' predecessors. - - The returned set is in terms of the entities present within 'a'. - - """ - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if all_overlap.intersection(elem._cloned_set) - ) - - -def _cloned_difference(a, b): - all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if not all_overlap.intersection(elem._cloned_set) - ) - - -@util.dependencies("sqlalchemy.sql.functions") -def _labeled(functions, element): - if not hasattr(element, "name") or isinstance( - element, functions.FunctionElement - ): - return element.label(None) - else: - return element - - -def _is_column(col): - """True if ``col`` is an instance of :class:`.ColumnElement`.""" - - return isinstance(col, ColumnElement) - - -def _find_columns(clause): - """locate Column objects within the given expression.""" - - cols = util.column_set() - traverse(clause, {}, {"column": cols.add}) - return cols - - -# there is some inconsistency here between the usage of -# inspect() vs. checking for Visitable and __clause_element__. -# Ideally all functions here would derive from inspect(), -# however the inspect() versions add significant callcount -# overhead for critical functions like _interpret_as_column_or_from(). -# Generally, the column-based functions are more performance critical -# and are fine just checking for __clause_element__(). It is only -# _interpret_as_from() where we'd like to be able to receive ORM entities -# that have no defined namespace, hence inspect() is needed there. - - -def _column_as_key(element): - if isinstance(element, util.string_types): - return element - if hasattr(element, "__clause_element__"): - element = element.__clause_element__() - try: - return element.key - except AttributeError: - return None - - -def _clause_element_as_expr(element): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - else: - return element - - -def _literal_as_label_reference(element): - if isinstance(element, util.string_types): - return _textual_label_reference(element) - - elif hasattr(element, "__clause_element__"): - element = element.__clause_element__() - - return _literal_as_text(element) - - -def _literal_and_labels_as_label_reference(element): - if isinstance(element, util.string_types): - return _textual_label_reference(element) - - elif hasattr(element, "__clause_element__"): - element = element.__clause_element__() - - if ( - isinstance(element, ColumnElement) - and element._order_by_label_element is not None - ): - return _label_reference(element) - else: - return _literal_as_text(element) - - -def _expression_literal_as_text(element): - return _literal_as_text(element) - - -def _literal_as(element, text_fallback): - if isinstance(element, Visitable): - return element - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif isinstance(element, util.string_types): - return text_fallback(element) - elif isinstance(element, (util.NoneType, bool)): - return _const_expr(element) - else: - raise exc.ArgumentError( - "SQL expression object expected, got object of type %r " - "instead" % type(element) - ) - - -def _literal_as_text(element, allow_coercion_to_text=False): - if allow_coercion_to_text: - return _literal_as(element, TextClause) - else: - return _literal_as(element, _no_text_coercion) - - -def _literal_as_column(element): - return _literal_as(element, ColumnClause) - - -def _no_column_coercion(element): - element = str(element) - guess_is_literal = not _guess_straight_column.match(element) - raise exc.ArgumentError( - "Textual column expression %(column)r should be " - "explicitly declared with text(%(column)r), " - "or use %(literal_column)s(%(column)r) " - "for more specificity" - % { - "column": util.ellipses_string(element), - "literal_column": "literal_column" - if guess_is_literal - else "column", - } - ) - - -def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None): - raise exc_cls( - "%(extra)sTextual SQL expression %(expr)r should be " - "explicitly declared as text(%(expr)r)" - % { - "expr": util.ellipses_string(element), - "extra": "%s " % extra if extra else "", - } - ) - - -def _no_literals(element): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif not isinstance(element, Visitable): - raise exc.ArgumentError( - "Ambiguous literal: %r. Use the 'text()' " - "function to indicate a SQL expression " - "literal, or 'literal()' to indicate a " - "bound value." % (element,) - ) - else: - return element - - -def _is_literal(element): - return not isinstance(element, Visitable) and not hasattr( - element, "__clause_element__" - ) - - -def _only_column_elements_or_none(element, name): - if element is None: - return None - else: - return _only_column_elements(element, name) - - -def _only_column_elements(element, name): - if hasattr(element, "__clause_element__"): - element = element.__clause_element__() - if not isinstance(element, ColumnElement): - raise exc.ArgumentError( - "Column-based expression object expected for argument " - "'%s'; got: '%s', type %s" % (name, element, type(element)) - ) - return element - - -def _literal_as_binds(element, name=None, type_=None): - if hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif not isinstance(element, Visitable): - if element is None: - return Null() - else: - return BindParameter(name, element, type_=type_, unique=True) - else: - return element - - -_guess_straight_column = re.compile(r"^\w\S*$", re.I) - - -def _interpret_as_column_or_from(element): - if isinstance(element, Visitable): - return element - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - - insp = inspection.inspect(element, raiseerr=False) - if insp is None: - if isinstance(element, (util.NoneType, bool)): - return _const_expr(element) - elif hasattr(insp, "selectable"): - return insp.selectable - - # be forgiving as this is an extremely common - # and known expression - if element == "*": - guess_is_literal = True - elif isinstance(element, (numbers.Number)): - return ColumnClause(str(element), is_literal=True) - else: - _no_column_coercion(element) - return ColumnClause(element, is_literal=guess_is_literal) - - -def _const_expr(element): - if isinstance(element, (Null, False_, True_)): - return element - elif element is None: - return Null() - elif element is False: - return False_() - elif element is True: - return True_() - else: - raise exc.ArgumentError("Expected None, False, or True") - - -def _type_from_args(args): - for a in args: - if not a.type._isnull: - return a.type - else: - return type_api.NULLTYPE - - -def _corresponding_column_or_error(fromclause, column, require_embedded=False): - c = fromclause.corresponding_column( - column, require_embedded=require_embedded - ) - if c is None: - raise exc.InvalidRequestError( - "Given column '%s', attached to table '%s', " - "failed to locate a corresponding column from table '%s'" - % (column, getattr(column, "table", None), fromclause.description) - ) - return c - - -class AnnotatedColumnElement(Annotated): - def __init__(self, element, values): - Annotated.__init__(self, element, values) - ColumnElement.comparator._reset(self) - for attr in ("name", "key", "table"): - if self.__dict__.get(attr, False) is None: - self.__dict__.pop(attr) - - def _with_annotations(self, values): - clone = super(AnnotatedColumnElement, self)._with_annotations(values) - ColumnElement.comparator._reset(clone) - return clone - - @util.memoized_property - def name(self): - """pull 'name' from parent, if not present""" - return self._Annotated__element.name - - @util.memoized_property - def table(self): - """pull 'table' from parent, if not present""" - return self._Annotated__element.table - - @util.memoized_property - def key(self): - """pull 'key' from parent, if not present""" - return self._Annotated__element.key - - @util.memoized_property - def info(self): - return self._Annotated__element.info - - @util.memoized_property - def anon_label(self): - return self._Annotated__element.anon_label diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f381879ce1..b04355cf5d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -67,7 +67,6 @@ __all__ = [ "outerjoin", "over", "select", - "subquery", "table", "text", "tuple_", @@ -92,22 +91,7 @@ from .dml import Insert # noqa from .dml import Update # noqa from .dml import UpdateBase # noqa from .dml import ValuesBase # noqa -from .elements import _clause_element_as_expr # noqa -from .elements import _clone # noqa -from .elements import _cloned_difference # noqa -from .elements import _cloned_intersection # noqa -from .elements import _column_as_key # noqa -from .elements import _corresponding_column_or_error # noqa -from .elements import _expression_literal_as_text # noqa -from .elements import _is_column # noqa -from .elements import _labeled # noqa -from .elements import _literal_as_binds # noqa -from .elements import _literal_as_column # noqa -from .elements import _literal_as_label_reference # noqa -from .elements import _literal_as_text # noqa -from .elements import _only_column_elements # noqa from .elements import _select_iterables # noqa -from .elements import _string_or_unprintable # noqa from .elements import _truncated_label # noqa from .elements import between # noqa from .elements import BinaryExpression # noqa @@ -147,7 +131,6 @@ from .functions import func # noqa from .functions import Function # noqa from .functions import FunctionElement # noqa from .functions import modifier # noqa -from .selectable import _interpret_as_from # noqa from .selectable import Alias # noqa from .selectable import CompoundSelect # noqa from .selectable import CTE # noqa @@ -160,6 +143,7 @@ from .selectable import HasPrefixes # noqa from .selectable import HasSuffixes # noqa from .selectable import Join # noqa from .selectable import Lateral # noqa +from .selectable import ReturnsRows # noqa from .selectable import ScalarSelect # noqa from .selectable import Select # noqa from .selectable import Selectable # noqa @@ -171,7 +155,6 @@ from .selectable import TextAsFrom # noqa from .visitors import Visitable # noqa from ..util.langhelpers import public_factory # noqa - # factory functions - these pull class-bound constructors and classmethods # from SQL elements and selectables into public functions. This allows # the functions to be available in the sqlalchemy.sql.* namespace and diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index d0aa239881..1737899986 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -9,14 +9,15 @@ """ from . import annotation +from . import coercions from . import operators +from . import roles from . import schema from . import sqltypes from . import util as sqlutil from .base import ColumnCollection from .base import Executable from .elements import _clone -from .elements import _literal_as_binds from .elements import _type_from_args from .elements import BinaryExpression from .elements import BindParameter @@ -83,7 +84,12 @@ class FunctionElement(Executable, ColumnElement, FromClause): """Construct a :class:`.FunctionElement`. """ args = [ - _literal_as_binds(c, getattr(self, "name", None)) for c in clauses + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + ) + for c in clauses ] self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( @@ -686,7 +692,12 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): def __init__(self, *args, **kwargs): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: - parsed_args = [_literal_as_binds(c, self.name) for c in args] + parsed_args = [ + coercions.expect( + roles.ExpressionElementRole, c, name=self.name + ) + for c in args + ] self._has_args = self._has_args or bool(parsed_args) self.packagenames = [] self._bind = kwargs.get("bind", None) @@ -751,7 +762,10 @@ class ReturnTypeFromArgs(GenericFunction): """Define a function whose return type is the same as its arguments.""" def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c, self.name) for c in args] + args = [ + coercions.expect(roles.ExpressionElementRole, c, name=self.name) + for c in args + ] kwargs.setdefault("type_", _type_from_args(args)) kwargs["_parsed_args"] = args super(ReturnTypeFromArgs, self).__init__(*args, **kwargs) @@ -880,7 +894,7 @@ class array_agg(GenericFunction): type = sqltypes.ARRAY def __init__(self, *args, **kwargs): - args = [_literal_as_binds(c) for c in args] + args = [coercions.expect(roles.ExpressionElementRole, c) for c in args] default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY) if "type_" not in kwargs: diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 8479c1d594..b8bbb45252 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1053,7 +1053,7 @@ class ColumnOperators(Operators): expr = 5 == mytable.c.somearray.any_() # mysql '5 = ANY (SELECT value FROM table)' - expr = 5 == select([table.c.value]).as_scalar().any_() + expr = 5 == select([table.c.value]).scalar_subquery().any_() .. seealso:: @@ -1078,7 +1078,7 @@ class ColumnOperators(Operators): expr = 5 == mytable.c.somearray.all_() # mysql '5 = ALL (SELECT value FROM table)' - expr = 5 == select([table.c.value]).as_scalar().all_() + expr = 5 == select([table.c.value]).scalar_subquery().all_() .. seealso:: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py new file mode 100644 index 0000000000..2d3aaf903a --- /dev/null +++ b/lib/sqlalchemy/sql/roles.py @@ -0,0 +1,157 @@ +# sql/roles.py +# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +class SQLRole(object): + """Define a "role" within a SQL statement structure. + + Classes within SQL Core participate within SQLRole hierarchies in order + to more accurately indicate where they may be used within SQL statements + of all types. + + .. versionadded:: 1.4 + + """ + + +class UsesInspection(object): + pass + + +class ColumnArgumentRole(SQLRole): + _role_name = "Column expression" + + +class ColumnArgumentOrKeyRole(ColumnArgumentRole): + _role_name = "Column expression or string key" + + +class ColumnListRole(SQLRole): + """Elements suitable for forming comma separated lists of expressions.""" + + +class TruncatedLabelRole(SQLRole): + _role_name = "String SQL identifier" + + +class ColumnsClauseRole(UsesInspection, ColumnListRole): + _role_name = "Column expression or FROM clause" + + @property + def _select_iterable(self): + raise NotImplementedError() + + +class LimitOffsetRole(SQLRole): + _role_name = "LIMIT / OFFSET expression" + + +class ByOfRole(ColumnListRole): + _role_name = "GROUP BY / OF / etc. expression" + + +class OrderByRole(ByOfRole): + _role_name = "ORDER BY expression" + + +class StructuralRole(SQLRole): + pass + + +class StatementOptionRole(StructuralRole): + _role_name = "statement sub-expression element" + + +class WhereHavingRole(StructuralRole): + _role_name = "SQL expression for WHERE/HAVING role" + + +class ExpressionElementRole(SQLRole): + _role_name = "SQL expression element" + + +class ConstExprRole(ExpressionElementRole): + _role_name = "Constant True/False/None expression" + + +class LabeledColumnExprRole(ExpressionElementRole): + pass + + +class BinaryElementRole(ExpressionElementRole): + _role_name = "SQL expression element or literal value" + + +class InElementRole(SQLRole): + _role_name = ( + "IN expression list, SELECT construct, or bound parameter object" + ) + + +class FromClauseRole(ColumnsClauseRole): + _role_name = "FROM expression, such as a Table or alias() object" + + @property + def _hide_froms(self): + raise NotImplementedError() + + +class CoerceTextStatementRole(SQLRole): + _role_name = "Executable SQL, text() construct, or string statement" + + +class StatementRole(CoerceTextStatementRole): + _role_name = "Executable SQL or text() construct" + + +class ReturnsRowsRole(StatementRole): + _role_name = ( + "Row returning expression such as a SELECT, or an " + "INSERT/UPDATE/DELETE with RETURNING" + ) + + +class SelectStatementRole(ReturnsRowsRole): + _role_name = "SELECT construct or equivalent text() construct" + + +class HasCTERole(ReturnsRowsRole): + pass + + +class CompoundElementRole(SQLRole): + """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc.""" + + _role_name = ( + "SELECT construct for inclusion in a UNION or other set construct" + ) + + +class DMLRole(StatementRole): + pass + + +class DMLColumnRole(SQLRole): + _role_name = "SET/VALUES column expression or string key" + + +class DMLSelectRole(SQLRole): + """A SELECT statement embedded in DML, typically INSERT from SELECT """ + + _role_name = "SELECT statement or equivalent textual object" + + +class DDLRole(StatementRole): + pass + + +class DDLExpressionRole(StructuralRole): + _role_name = "SQL expression element for DDL constraint" + + +class DDLConstraintColumnRole(SQLRole): + _role_name = "String column name or column object for DDL constraint" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index b045e006e7..62ff25a646 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -34,16 +34,16 @@ import collections import operator import sqlalchemy +from . import coercions from . import ddl +from . import roles from . import type_api from . import visitors from .base import _bind_or_error from .base import ColumnCollection from .base import DialectKWArgs from .base import SchemaEventTarget -from .elements import _as_truncated -from .elements import _document_text_coercion -from .elements import _literal_as_text +from .coercions import _document_text_coercion from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement @@ -1583,7 +1583,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): ) try: c = self._constructor( - _as_truncated(name or self.name) + coercions.expect( + roles.TruncatedLabelRole, name if name else self.name + ) if name_is_truncatable else (name or self.name), self.type, @@ -2109,13 +2111,19 @@ class ForeignKey(DialectKWArgs, SchemaItem): class _NotAColumnExpr(object): + # the coercions system is not used in crud.py for the values passed in + # the insert().values() and update().values() methods, so the usual + # pathways to rejecting a coercion in the unlikely case of adding defaut + # generator objects to insert() or update() constructs aren't available; + # create a quick coercion rejection here that is specific to what crud.py + # calls on value objects. def _not_a_column_expr(self): raise exc.InvalidRequestError( "This %s cannot be used directly " "as a column expression." % self.__class__.__name__ ) - __clause_element__ = self_group = lambda self: self._not_a_column_expr() + self_group = lambda self: self._not_a_column_expr() # noqa _from_objects = property(lambda self: self._not_a_column_expr()) @@ -2274,7 +2282,7 @@ class ColumnDefault(DefaultGenerator): return "ColumnDefault(%r)" % (self.arg,) -class Sequence(DefaultGenerator): +class Sequence(roles.StatementRole, DefaultGenerator): """Represents a named database sequence. The :class:`.Sequence` object represents the name and configurational @@ -2759,25 +2767,6 @@ class ColumnCollectionMixin(object): if _autoattach and self._pending_colargs: self._check_attach() - @classmethod - def _extract_col_expression_collection(cls, expressions): - for expr in expressions: - strname = None - column = None - if hasattr(expr, "__clause_element__"): - expr = expr.__clause_element__() - - if not isinstance(expr, (ColumnElement, TextClause)): - # this assumes a string - strname = expr - else: - cols = [] - visitors.traverse(expr, {}, {"column": cols.append}) - if cols: - column = cols[0] - add_element = column if column is not None else strname - yield expr, column, strname, add_element - def _check_attach(self, evt=False): col_objs = [c for c in self._pending_colargs if isinstance(c, Column)] @@ -2960,7 +2949,7 @@ class CheckConstraint(ColumnCollectionConstraint): """ - self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True) + self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext) columns = [] visitors.traverse(self.sqltext, {}, {"column": columns.append}) @@ -3630,7 +3619,9 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): column, strname, add_element, - ) in self._extract_col_expression_collection(expressions): + ) in coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ): if add_element is not None: columns.append(add_element) processed_expressions.append(expr) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5167182fe2..41be9fc5a3 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -15,8 +15,9 @@ import itertools import operator from operator import attrgetter -from sqlalchemy.sql.visitors import Visitable +from . import coercions from . import operators +from . import roles from . import type_api from .annotation import Annotated from .base import _from_objects @@ -26,18 +27,12 @@ from .base import ColumnSet from .base import Executable from .base import Generative from .base import Immutable +from .coercions import _document_text_coercion from .elements import _anonymous_label -from .elements import _clause_element_as_expr from .elements import _clone from .elements import _cloned_difference from .elements import _cloned_intersection -from .elements import _document_text_coercion from .elements import _expand_cloned -from .elements import _interpret_as_column_or_from -from .elements import _literal_and_labels_as_label_reference -from .elements import _literal_as_label_reference -from .elements import _literal_as_text -from .elements import _no_text_coercion from .elements import _select_iterables from .elements import and_ from .elements import BindParameter @@ -48,75 +43,15 @@ from .elements import literal_column from .elements import True_ from .elements import UnaryExpression from .. import exc -from .. import inspection from .. import util -def _interpret_as_from(element): - insp = inspection.inspect(element, raiseerr=False) - if insp is None: - if isinstance(element, util.string_types): - _no_text_coercion(element) - try: - return insp.selectable - except AttributeError: - raise exc.ArgumentError("FROM expression expected") - - -def _interpret_as_select(element): - element = _interpret_as_from(element) - if isinstance(element, Alias): - element = element.original - if not isinstance(element, SelectBase): - element = element.select() - return element - - class _OffsetLimitParam(BindParameter): @property def _limit_offset_value(self): return self.effective_value -def _offset_or_limit_clause(element, name=None, type_=None): - """Convert the given value to an "offset or limit" clause. - - This handles incoming integers and converts to an expression; if - an expression is already given, it is passed through. - - """ - if element is None: - return None - elif hasattr(element, "__clause_element__"): - return element.__clause_element__() - elif isinstance(element, Visitable): - return element - else: - value = util.asint(element) - return _OffsetLimitParam(name, value, type_=type_, unique=True) - - -def _offset_or_limit_clause_asint(clause, attrname): - """Convert the "offset or limit" clause of a select construct to an - integer. - - This is only possible if the value is stored as a simple bound parameter. - Otherwise, a compilation error is raised. - - """ - if clause is None: - return None - try: - value = clause._limit_offset_value - except AttributeError: - raise exc.CompileError( - "This SELECT structure does not use a simple " - "integer value for %s" % attrname - ) - else: - return util.asint(value) - - def subquery(alias, *args, **kwargs): r"""Return an :class:`.Alias` object derived from a :class:`.Select`. @@ -133,8 +68,42 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) -class Selectable(ClauseElement): - """mark a class as being selectable""" +class ReturnsRows(roles.ReturnsRowsRole, ClauseElement): + """The basemost class for Core contructs that have some concept of + columns that can represent rows. + + While the SELECT statement and TABLE are the primary things we think + of in this category, DML like INSERT, UPDATE and DELETE can also specify + RETURNING which means they can be used in CTEs and other forms, and + PostgreSQL has functions that return rows also. + + .. versionadded:: 1.4 + + """ + + _is_returns_rows = True + + # sub-elements of returns_rows + _is_from_clause = False + _is_select_statement = False + _is_lateral = False + + @property + def selectable(self): + raise NotImplementedError( + "This object is a base ReturnsRows object, but is not a " + "FromClause so has no .c. collection." + ) + + +class Selectable(ReturnsRows): + """mark a class as being selectable. + + This class is legacy as of 1.4 as the concept of a SQL construct which + "returns rows" is more generalized than one which can be the subject + of a SELECT. + + """ __visit_name__ = "selectable" @@ -190,7 +159,7 @@ class HasPrefixes(object): def _setup_prefixes(self, prefixes, dialect=None): self._prefixes = self._prefixes + tuple( [ - (_literal_as_text(p, allow_coercion_to_text=True), dialect) + (coercions.expect(roles.StatementOptionRole, p), dialect) for p in prefixes ] ) @@ -236,13 +205,13 @@ class HasSuffixes(object): def _setup_suffixes(self, suffixes, dialect=None): self._suffixes = self._suffixes + tuple( [ - (_literal_as_text(p, allow_coercion_to_text=True), dialect) + (coercions.expect(roles.StatementOptionRole, p), dialect) for p in suffixes ] ) -class FromClause(Selectable): +class FromClause(roles.FromClauseRole, Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -265,16 +234,6 @@ class FromClause(Selectable): named_with_column = False _hide_froms = [] - _is_join = False - _is_select = False - _is_from_container = False - - _is_lateral = False - - _textual = False - """a marker that allows us to easily distinguish a :class:`.TextAsFrom` - or similar object from other kinds of :class:`.FromClause` objects.""" - schema = None """Define the 'schema' attribute for this :class:`.FromClause`. @@ -284,6 +243,11 @@ class FromClause(Selectable): """ + is_selectable = has_selectable = True + _is_from_clause = True + _is_text_as_from = False + _is_join = False + def _translate_schema(self, effective_schema, map_): return effective_schema @@ -726,8 +690,8 @@ class Join(FromClause): :class:`.FromClause` object. """ - self.left = _interpret_as_from(left) - self.right = _interpret_as_from(right).self_group() + self.left = coercions.expect(roles.FromClauseRole, left) + self.right = coercions.expect(roles.FromClauseRole, right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) @@ -1292,7 +1256,9 @@ class Alias(FromClause): .. versionadded:: 0.9.0 """ - return _interpret_as_from(selectable).alias(name=name, flat=flat) + return coercions.expect(roles.FromClauseRole, selectable).alias( + name=name, flat=flat + ) def _init(self, selectable, name=None): baseselectable = selectable @@ -1327,14 +1293,6 @@ class Alias(FromClause): else: return self.name.encode("ascii", "backslashreplace") - def as_scalar(self): - try: - return self.element.as_scalar() - except AttributeError: - raise AttributeError( - "Element %s does not support " "'as_scalar()'" % self.element - ) - def is_derived_from(self, fromclause): if fromclause in self._cloned_set: return True @@ -1426,7 +1384,9 @@ class Lateral(Alias): :ref:`lateral_selects` - overview of usage. """ - return _interpret_as_from(selectable).lateral(name=name) + return coercions.expect(roles.FromClauseRole, selectable).lateral( + name=name + ) class TableSample(Alias): @@ -1488,7 +1448,7 @@ class TableSample(Alias): REPEATABLE sub-clause is also rendered. """ - return _interpret_as_from(selectable).tablesample( + return coercions.expect(roles.FromClauseRole, selectable).tablesample( sampling, name=name, seed=seed ) @@ -1523,7 +1483,7 @@ class CTE(Generative, HasSuffixes, Alias): Please see :meth:`.HasCte.cte` for detail on CTE usage. """ - return _interpret_as_from(selectable).cte( + return coercions.expect(roles.HasCTERole, selectable).cte( name=name, recursive=recursive ) @@ -1588,7 +1548,7 @@ class CTE(Generative, HasSuffixes, Alias): ) -class HasCTE(object): +class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. .. versionadded:: 1.1 @@ -2059,13 +2019,22 @@ class ForUpdateArg(ClauseElement): self.key_share = key_share if of is not None: self.of = [ - _interpret_as_column_or_from(elem) for elem in util.to_list(of) + coercions.expect(roles.ColumnsClauseRole, elem) + for elem in util.to_list(of) ] else: self.of = None -class SelectBase(HasCTE, Executable, FromClause): +class SelectBase( + roles.SelectStatementRole, + roles.DMLSelectRole, + roles.CompoundElementRole, + roles.InElementRole, + HasCTE, + Executable, + FromClause, +): """Base class for SELECT statements. @@ -2075,15 +2044,32 @@ class SelectBase(HasCTE, Executable, FromClause): """ + _is_select_statement = True + + @util.deprecated( + "1.4", + "The :meth:`.SelectBase.as_scalar` method is deprecated and will be " + "removed in a future release. Please refer to " + ":meth:`.SelectBase.scalar_subquery`.", + ) def as_scalar(self): + return self.scalar_subquery() + + def scalar_subquery(self): """return a 'scalar' representation of this selectable, which can be used as a column expression. Typically, a select statement which has only one column in its columns - clause is eligible to be used as a scalar expression. + clause is eligible to be used as a scalar expression. The scalar + subquery can then be used in the WHERE clause or columns clause of + an enclosing SELECT. - The returned object is an instance of - :class:`ScalarSelect`. + Note that the scalar subquery differentiates from the FROM-level + subquery that can be produced using the :meth:`.SelectBase.subquery` + method. + + .. versionchanged: 1.4 - the ``.as_scalar()`` method was renamed to + :meth:`.SelectBase.scalar_subquery`. """ return ScalarSelect(self) @@ -2097,7 +2083,7 @@ class SelectBase(HasCTE, Executable, FromClause): :meth:`~.SelectBase.as_scalar`. """ - return self.as_scalar().label(name) + return self.scalar_subquery().label(name) @_generative @util.deprecated( @@ -2181,20 +2167,19 @@ class GenerativeSelect(SelectBase): {"autocommit": autocommit} ) if limit is not None: - self._limit_clause = _offset_or_limit_clause(limit) + self._limit_clause = self._offset_or_limit_clause(limit) if offset is not None: - self._offset_clause = _offset_or_limit_clause(offset) + self._offset_clause = self._offset_or_limit_clause(offset) self._bind = bind if order_by is not None: self._order_by_clause = ClauseList( *util.to_list(order_by), - _literal_as_text=_literal_and_labels_as_label_reference + _literal_as_text_role=roles.OrderByRole ) if group_by is not None: self._group_by_clause = ClauseList( - *util.to_list(group_by), - _literal_as_text=_literal_as_label_reference + *util.to_list(group_by), _literal_as_text_role=roles.ByOfRole ) @property @@ -2287,6 +2272,37 @@ class GenerativeSelect(SelectBase): """ self.use_labels = True + def _offset_or_limit_clause(self, element, name=None, type_=None): + """Convert the given value to an "offset or limit" clause. + + This handles incoming integers and converts to an expression; if + an expression is already given, it is passed through. + + """ + return coercions.expect( + roles.LimitOffsetRole, element, name=name, type_=type_ + ) + + def _offset_or_limit_clause_asint(self, clause, attrname): + """Convert the "offset or limit" clause of a select construct to an + integer. + + This is only possible if the value is stored as a simple bound + parameter. Otherwise, a compilation error is raised. + + """ + if clause is None: + return None + try: + value = clause._limit_offset_value + except AttributeError: + raise exc.CompileError( + "This SELECT structure does not use a simple " + "integer value for %s" % attrname + ) + else: + return util.asint(value) + @property def _limit(self): """Get an integer value for the limit. This should only be used @@ -2295,7 +2311,7 @@ class GenerativeSelect(SelectBase): isn't currently set to an integer. """ - return _offset_or_limit_clause_asint(self._limit_clause, "limit") + return self._offset_or_limit_clause_asint(self._limit_clause, "limit") @property def _simple_int_limit(self): @@ -2319,7 +2335,9 @@ class GenerativeSelect(SelectBase): offset isn't currently set to an integer. """ - return _offset_or_limit_clause_asint(self._offset_clause, "offset") + return self._offset_or_limit_clause_asint( + self._offset_clause, "offset" + ) @_generative def limit(self, limit): @@ -2339,7 +2357,7 @@ class GenerativeSelect(SelectBase): """ - self._limit_clause = _offset_or_limit_clause(limit) + self._limit_clause = self._offset_or_limit_clause(limit) @_generative def offset(self, offset): @@ -2361,7 +2379,7 @@ class GenerativeSelect(SelectBase): """ - self._offset_clause = _offset_or_limit_clause(offset) + self._offset_clause = self._offset_or_limit_clause(offset) @_generative def order_by(self, *clauses): @@ -2403,8 +2421,7 @@ class GenerativeSelect(SelectBase): if getattr(self, "_order_by_clause", None) is not None: clauses = list(self._order_by_clause) + list(clauses) self._order_by_clause = ClauseList( - *clauses, - _literal_as_text=_literal_and_labels_as_label_reference + *clauses, _literal_as_text_role=roles.OrderByRole ) def append_group_by(self, *clauses): @@ -2423,7 +2440,7 @@ class GenerativeSelect(SelectBase): if getattr(self, "_group_by_clause", None) is not None: clauses = list(self._group_by_clause) + list(clauses) self._group_by_clause = ClauseList( - *clauses, _literal_as_text=_literal_as_label_reference + *clauses, _literal_as_text_role=roles.ByOfRole ) @property @@ -2478,7 +2495,7 @@ class CompoundSelect(GenerativeSelect): # some DBs do not like ORDER BY in the inner queries of a UNION, etc. for n, s in enumerate(selects): - s = _clause_element_as_expr(s) + s = coercions.expect(roles.CompoundElementRole, s) if not numcols: numcols = len(s.c._all_columns) @@ -2741,7 +2758,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): _correlate = () _correlate_except = None _memoized_property = SelectBase._memoized_property - _is_select = True @util.deprecated_params( autocommit=( @@ -2965,12 +2981,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._distinct = True else: self._distinct = [ - _literal_as_text(e) for e in util.to_list(distinct) + coercions.expect(roles.WhereHavingRole, e) + for e in util.to_list(distinct) ] if from_obj is not None: self._from_obj = util.OrderedSet( - _interpret_as_from(f) for f in util.to_list(from_obj) + coercions.expect(roles.FromClauseRole, f) + for f in util.to_list(from_obj) ) else: self._from_obj = util.OrderedSet() @@ -2986,7 +3004,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): if cols_present: self._raw_columns = [] for c in columns: - c = _interpret_as_column_or_from(c) + c = coercions.expect(roles.ColumnsClauseRole, c) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) self._raw_columns.append(c) @@ -2994,16 +3012,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._raw_columns = [] if whereclause is not None: - self._whereclause = _literal_as_text(whereclause).self_group( - against=operators._asbool - ) + self._whereclause = coercions.expect( + roles.WhereHavingRole, whereclause + ).self_group(against=operators._asbool) else: self._whereclause = None if having is not None: - self._having = _literal_as_text(having).self_group( - against=operators._asbool - ) + self._having = coercions.expect( + roles.WhereHavingRole, having + ).self_group(against=operators._asbool) else: self._having = None @@ -3202,15 +3220,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): else: self._hints = self._hints.union({(selectable, dialect_name): text}) - @property - def type(self): - raise exc.InvalidRequestError( - "Select objects don't have a type. " - "Call as_scalar() on this Select " - "object to return a 'scalar' version " - "of this Select." - ) - @_memoized_property.method def locate_all_froms(self): """return a Set of all FromClause elements referenced by this Select. @@ -3496,7 +3505,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._reset_exported() rc = [] for c in columns: - c = _interpret_as_column_or_from(c) + c = coercions.expect(roles.ColumnsClauseRole, c) if isinstance(c, ScalarSelect): c = c.self_group(against=operators.comma_op) rc.append(c) @@ -3530,7 +3539,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ if expr: - expr = [_literal_as_label_reference(e) for e in expr] + expr = [coercions.expect(roles.ByOfRole, e) for e in expr] if isinstance(self._distinct, list): self._distinct = self._distinct + expr else: @@ -3618,7 +3627,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate = () else: self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclauses + coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) @_generative @@ -3653,7 +3662,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._correlate_except = () else: self._correlate_except = set(self._correlate_except or ()).union( - _interpret_as_from(f) for f in fromclauses + coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) def append_correlation(self, fromclause): @@ -3668,7 +3677,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): self._auto_correlate = False self._correlate = set(self._correlate).union( - _interpret_as_from(f) for f in fromclause + coercions.expect(roles.FromClauseRole, f) for f in fromclause ) def append_column(self, column): @@ -3689,7 +3698,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - column = _interpret_as_column_or_from(column) + column = coercions.expect(roles.ColumnsClauseRole, column) if isinstance(column, ScalarSelect): column = column.self_group(against=operators.comma_op) @@ -3705,7 +3714,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): standard :term:`method chaining`. """ - clause = _literal_as_text(clause) + clause = coercions.expect(roles.WhereHavingRole, clause) self._prefixes = self._prefixes + (clause,) def append_whereclause(self, whereclause): @@ -3747,7 +3756,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): """ self._reset_exported() - fromclause = _interpret_as_from(fromclause) + fromclause = coercions.expect(roles.FromClauseRole, fromclause) self._from_obj = self._from_obj.union([fromclause]) @_memoized_property @@ -3894,7 +3903,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect): bind = property(bind, _set_bind) -class ScalarSelect(Generative, Grouping): +class ScalarSelect(roles.InElementRole, Generative, Grouping): _from_objects = [] _is_from_container = True _is_implicitly_boolean = False @@ -3956,7 +3965,7 @@ class Exists(UnaryExpression): else: if not args: args = ([literal_column("*")],) - s = Select(*args, **kwargs).as_scalar().self_group() + s = Select(*args, **kwargs).scalar_subquery().self_group() UnaryExpression.__init__( self, @@ -3999,6 +4008,7 @@ class Exists(UnaryExpression): return e +# TODO: rename to TextualSelect, this is not a FROM clause class TextAsFrom(SelectBase): """Wrap a :class:`.TextClause` construct within a :class:`.SelectBase` interface. @@ -4022,7 +4032,7 @@ class TextAsFrom(SelectBase): __visit_name__ = "text_as_from" - _textual = True + _is_textual = True def __init__(self, text, columns, positional=False): self.element = text diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 0d39445527..6a520a2d59 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -14,13 +14,14 @@ import datetime as dt import decimal import json +from . import coercions from . import elements from . import operators +from . import roles from . import type_api from .base import _bind_or_error from .base import SchemaEventTarget from .elements import _defer_name -from .elements import _literal_as_binds from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa @@ -2187,19 +2188,21 @@ class JSON(Indexable, TypeEngine): if not isinstance(index, util.string_types) and isinstance( index, compat.collections_abc.Sequence ): - index = default_comparator._check_literal( - self.expr, - operators.json_path_getitem_op, + index = coercions.expect( + roles.BinaryElementRole, index, + expr=self.expr, + operator=operators.json_path_getitem_op, bindparam_type=JSON.JSONPathType, ) operator = operators.json_path_getitem_op else: - index = default_comparator._check_literal( - self.expr, - operators.json_getitem_op, + index = coercions.expect( + roles.BinaryElementRole, index, + expr=self.expr, + operator=operators.json_getitem_op, bindparam_type=JSON.JSONIndexType, ) operator = operators.json_getitem_op @@ -2372,17 +2375,20 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): if self.type.zero_indexes: index = slice(index.start + 1, index.stop + 1, index.step) index = Slice( - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.start, name=self.expr.key, type_=type_api.INTEGERTYPE, ), - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.stop, name=self.expr.key, type_=type_api.INTEGERTYPE, ), - _literal_as_binds( + coercions.expect( + roles.ExpressionElementRole, index.step, name=self.expr.key, type_=type_api.INTEGERTYPE, @@ -2438,7 +2444,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ operator = operator if operator else operators.eq return operator( - elements._literal_as_binds(other), + coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_any(self.expr), ) @@ -2473,7 +2479,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): """ operator = operator if operator else operators.eq return operator( - elements._literal_as_binds(other), + coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_all(self.expr), ) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bdeae96137..5eea27e08a 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -57,6 +57,9 @@ class TypeEngine(Visitable): default_comparator = None + def __clause_element__(self): + return self.expr + def __init__(self, expr): self.expr = expr self.type = expr.type diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index c52e9a76f1..090c8488a5 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -23,6 +23,7 @@ from .assertions import expect_warnings # noqa from .assertions import in_ # noqa from .assertions import is_ # noqa from .assertions import is_false # noqa +from .assertions import is_instance_of # noqa from .assertions import is_not_ # noqa from .assertions import is_true # noqa from .assertions import le_ # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index d8038e225c..819fedcc77 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -247,6 +247,10 @@ def le_(a, b, msg=None): assert a <= b, msg or "%r != %r" % (a, b) +def is_instance_of(a, b, msg=None): + assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b) + + def is_true(a, msg=None): is_(a, True, msg=msg) diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py index 012de7911d..c7e6a266ca 100644 --- a/lib/sqlalchemy/testing/suite/test_cte.py +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -198,9 +198,9 @@ class CTETest(fixtures.TablesTest): conn.execute( some_other_table.delete().where( some_other_table.c.data - == select([cte.c.data]).where( - cte.c.id == some_other_table.c.id - ) + == select([cte.c.data]) + .where(cte.c.id == some_other_table.c.id) + .scalar_subquery() ) ) eq_( diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index a07e45df2d..8bf66d2eb9 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -98,7 +98,7 @@ class RowFetchTest(fixtures.TablesTest): """ datetable = self.tables.has_dates - s = select([datetable.alias("x").c.today]).as_scalar() + s = select([datetable.alias("x").c.today]).scalar_subquery() s2 = select([datetable.c.id, s.label("somelabel")]) row = config.db.execute(s2).first() diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index e6e4907abb..a8cdc5ef7f 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -140,6 +140,7 @@ from .langhelpers import portable_instancemethod # noqa from .langhelpers import quoted_token_parser # noqa from .langhelpers import safe_reraise # noqa from .langhelpers import set_creation_order # noqa +from .langhelpers import string_or_unprintable # noqa from .langhelpers import symbol # noqa from .langhelpers import unbound_method_to_callable # noqa from .langhelpers import warn # noqa diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 7d1321e0b8..7a7faff60c 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -79,6 +79,16 @@ class safe_reraise(object): compat.reraise(type_, value, traceback) +def string_or_unprintable(element): + if isinstance(element, compat.string_types): + return element + else: + try: + return str(element) + except Exception: + return "unprintable element %r" % element + + def clsname_as_plain_name(cls): return " ".join( n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__) diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 30a11d16b2..498de763c6 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -267,7 +267,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): stmt = table.update().values( val=select([other.c.newval]) .where(table.c.sym == other.c.sym) - .as_scalar() + .scalar_subquery() ) self.assert_compile( @@ -334,14 +334,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): t = table("sometable", column("somecolumn")) self.assert_compile( - t.select().where(t.c.somecolumn == t.select()), + t.select().where(t.c.somecolumn == t.select().scalar_subquery()), "SELECT sometable.somecolumn FROM " "sometable WHERE sometable.somecolumn = " "(SELECT sometable.somecolumn FROM " "sometable)", ) self.assert_compile( - t.select().where(t.c.somecolumn != t.select()), + t.select().where(t.c.somecolumn != t.select().scalar_subquery()), "SELECT sometable.somecolumn FROM " "sometable WHERE sometable.somecolumn != " "(SELECT sometable.somecolumn FROM " @@ -844,7 +844,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): t1 = table("t1", column("x", Integer), column("y", Integer)) t2 = table("t2", column("x", Integer), column("y", Integer)) - order_by = select([t2.c.y]).where(t1.c.x == t2.c.x).as_scalar() + order_by = select([t2.c.y]).where(t1.c.x == t2.c.x).scalar_subquery() s = ( select([t1]) .where(t1.c.x == 5) @@ -1135,7 +1135,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): idx = Index("test_idx_data_1", tbl.c.data, mssql_where=tbl.c.data > 1) self.assert_compile( schema.CreateIndex(idx), - "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1" + "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1", ) def test_index_ordering(self): diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index bf836fc147..4ecf0634c9 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -130,7 +130,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): def test_column_subquery_to_alias(self): a1 = self.t2.alias("a1") - s = select([self.t2, select([a1.c.a]).as_scalar()]) + s = select([self.t2, select([a1.c.a]).scalar_subquery()]) self._assert_sql( s, "SELECT t2_1.a, t2_1.b, t2_1.c, " diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index e492f1e17f..39485ae102 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -247,7 +247,7 @@ class AnyAllTest(fixtures.TablesTest): def test_any_w_comparator(self): stuff = self.tables.stuff stmt = select([stuff.c.id]).where( - stuff.c.value > any_(select([stuff.c.value])) + stuff.c.value > any_(select([stuff.c.value]).scalar_subquery()) ) eq_(testing.db.execute(stmt).fetchall(), [(2,), (3,), (4,), (5,)]) @@ -255,13 +255,13 @@ class AnyAllTest(fixtures.TablesTest): def test_all_w_comparator(self): stuff = self.tables.stuff stmt = select([stuff.c.id]).where( - stuff.c.value >= all_(select([stuff.c.value])) + stuff.c.value >= all_(select([stuff.c.value]).scalar_subquery()) ) eq_(testing.db.execute(stmt).fetchall(), [(5,)]) def test_any_literal(self): stuff = self.tables.stuff - stmt = select([4 == any_(select([stuff.c.value]))]) + stmt = select([4 == any_(select([stuff.c.value]).scalar_subquery())]) is_(testing.db.execute(stmt).scalar(), True) diff --git a/test/ext/declarative/test_basic.py b/test/ext/declarative/test_basic.py index 3fe2f1bfe9..b6c911813b 100644 --- a/test/ext/declarative/test_basic.py +++ b/test/ext/declarative/test_basic.py @@ -705,7 +705,7 @@ class DeclarativeTest(DeclarativeTestBase): rel = relationship("User", primaryjoin="User.id==Bar.__table__.id") assert_raises_message( - exc.InvalidRequestError, + AttributeError, "does not have a mapped column named " "'__table__'", configure_mappers, ) @@ -1469,7 +1469,7 @@ class DeclarativeTest(DeclarativeTestBase): User.address_count = sa.orm.column_property( sa.select([sa.func.count(Address.id)]) .where(Address.user_id == User.id) - .as_scalar() + .scalar_subquery() ) Base.metadata.create_all() u1 = User( @@ -1514,9 +1514,9 @@ class DeclarativeTest(DeclarativeTestBase): # this doesn't really gain us anything. but if # one is used, lets have it function as expected... return sa.orm.column_property( - sa.select([sa.func.count(Address.id)]).where( - Address.user_id == cls.id - ) + sa.select([sa.func.count(Address.id)]) + .where(Address.user_id == cls.id) + .scalar_subquery() ) Base.metadata.create_all() @@ -1616,7 +1616,7 @@ class DeclarativeTest(DeclarativeTestBase): adr_count = sa.orm.column_property( sa.select( [sa.func.count(Address.id)], Address.user_id == id - ).as_scalar() + ).scalar_subquery() ) addresses = relationship(Address) @@ -1920,7 +1920,7 @@ class DeclarativeTest(DeclarativeTestBase): User.address_count = sa.orm.column_property( sa.select([sa.func.count(Address.id)]) .where(Address.user_id == User.id) - .as_scalar() + .scalar_subquery() ) Base.metadata.create_all() u1 = User( diff --git a/test/ext/declarative/test_mixin.py b/test/ext/declarative/test_mixin.py index ef9bbd354d..df7dea77c4 100644 --- a/test/ext/declarative/test_mixin.py +++ b/test/ext/declarative/test_mixin.py @@ -1851,7 +1851,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): return column_property( select([func.count(Address.id)]) .where(Address.user_id == cls.id) - .as_scalar() + .scalar_subquery() ) class Address(Base): diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index 55cd9376bf..00c6a78b4d 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -765,9 +765,11 @@ class ResultTest(BakedTest): ) main_bq = self.bakery( - lambda s: s.query(Address.id, sub_bq.to_query(s).as_scalar()) + lambda s: s.query(Address.id, sub_bq.to_query(s).scalar_subquery()) + ) + main_bq += lambda q: q.filter( + sub_bq.to_query(q).scalar_subquery() == "ed" ) - main_bq += lambda q: q.filter(sub_bq.to_query(q).as_scalar() == "ed") main_bq += lambda q: q.order_by(Address.id) sess = Session() diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 525824669c..07ffd93857 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -2184,6 +2184,7 @@ class CorrelateExceptWPolyAdaptTest( select([func.count(Superclass.id)]) .where(Superclass.common_id == id) .correlate_except(Superclass) + .scalar_subquery() ) if not use_correlate_except: @@ -2191,6 +2192,7 @@ class CorrelateExceptWPolyAdaptTest( select([func.count(Superclass.id)]) .where(Superclass.common_id == Common.id) .correlate(Common) + .scalar_subquery() ) return Common, Superclass diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 59ecc7c986..472dafcc4b 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -344,8 +344,8 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest): assert_raises_message( sa_exc.ArgumentError, - "Only direct column-mapped property or " - "SQL expression can be passed for polymorphic_on", + r"Column expression or string key expected for argument " + r"'polymorphic_on'; got .*function", go, ) diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index c16573b237..bbf0d472a1 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -1295,7 +1295,7 @@ class _PolymorphicTestBase(object): subq = ( sess.query(engineers.c.person_id) .filter(Engineer.primary_language == "java") - .statement.as_scalar() + .statement.scalar_subquery() ) eq_(sess.query(Person).filter(Person.person_id.in_(subq)).one(), e1) @@ -1634,15 +1634,16 @@ class _PolymorphicTestBase(object): # this for a long time did not work with PolymorphicAliased and # PolymorphicUnions, which was due to the no_replacement_traverse - # annotation added to query.statement which then went into as_scalar(). - # this is removed as of :ticket:`4304` so now works. + # annotation added to query.statement which then went into + # scalar_subquery(). this is removed as of :ticket:`4304` so now + # works. eq_( sess.query(Person.name) .filter( sess.query(Company.name) .filter(Company.company_id == Person.company_id) .correlate(Person) - .as_scalar() + .scalar_subquery() == "Elbonia, Inc." ) .all(), @@ -1660,7 +1661,7 @@ class _PolymorphicTestBase(object): sess.query(Company.name) .filter(Company.company_id == paliased.company_id) .correlate(paliased) - .as_scalar() + .scalar_subquery() == "Elbonia, Inc." ) .all(), @@ -1678,7 +1679,7 @@ class _PolymorphicTestBase(object): sess.query(Company.name) .filter(Company.company_id == paliased.company_id) .correlate(paliased) - .as_scalar() + .scalar_subquery() == "Elbonia, Inc." ) .all(), @@ -1720,7 +1721,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): sess.query(Company.name) .filter(Company.company_id == p_poly.company_id) .correlate(p_poly) - .as_scalar() + .scalar_subquery() == "Elbonia, Inc." ) .all(), @@ -1739,7 +1740,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic): sess.query(Company.name) .filter(Company.company_id == p_poly.company_id) .correlate(p_poly) - .as_scalar() + .scalar_subquery() == "Elbonia, Inc." ) .all(), diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 1b28974b7b..a54cfbe933 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -962,7 +962,7 @@ class RelationshipToSingleTest( .select_from(Engineer) .filter(Engineer.company_id == Company.company_id) .correlate(Company) - .as_scalar() + .scalar_subquery() ) self.assert_compile( diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 83f4f44511..eb0df3107f 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -1827,15 +1827,15 @@ class DictHelpersTest(fixtures.MappedTest): def test_column_mapped_assertions(self): assert_raises_message( sa_exc.ArgumentError, - "Column-based expression object expected " - "for argument 'mapping_spec'; got: 'a'", + "Column expression expected " + "for argument 'mapping_spec'; got 'a'.", collections.column_mapped_collection, "a", ) assert_raises_message( sa_exc.ArgumentError, - "Column-based expression object expected " - "for argument 'mapping_spec'; got: 'a'", + "Column expression expected " + "for argument 'mapping_spec'; got .*TextClause.", collections.column_mapped_collection, text("a"), ) diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 679b3ec5bd..c5938416cd 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -550,6 +550,17 @@ class StrongIdentityMapTest(_fixtures.FixtureTest): class DeprecatedMapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): __dialect__ = "default" + def test_query_as_scalar(self): + users, User = self.tables.users, self.classes.User + + mapper(User, users) + s = Session() + with assertions.expect_deprecated( + r"The Query.as_scalar\(\) method is deprecated and will " + "be removed in a future release." + ): + s.query(User).as_scalar() + def test_cancel_order_by(self): users, User = self.tables.users, self.classes.User diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 4adf9a72f2..499803543d 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -3258,7 +3258,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): b_table.c.a_id == a_table.c.id ) - self._fixture({"summation": column_property(cp)}) + self._fixture({"summation": column_property(cp.scalar_subquery())}) self.assert_compile( create_session() .query(A) @@ -3281,7 +3281,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): b_table.c.a_id == a_table.c.id ) - self._fixture({"summation": column_property(cp)}) + self._fixture({"summation": column_property(cp.scalar_subquery())}) self.assert_compile( create_session() .query(A) @@ -3306,7 +3306,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): .correlate(a_table) ) - self._fixture({"summation": column_property(cp)}) + self._fixture({"summation": column_property(cp.scalar_subquery())}) self.assert_compile( create_session() .query(A) @@ -3330,7 +3330,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): select([func.sum(b_table.c.value)]) .where(b_table.c.a_id == a_table.c.id) .correlate(a_table) - .as_scalar() + .scalar_subquery() ) # up until 0.8, this was ordering by a new subquery. @@ -3360,7 +3360,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): select([func.sum(b_table.c.value)]) .where(b_table.c.a_id == a_table.c.id) .correlate(a_table) - .as_scalar() + .scalar_subquery() .label("foo") ) self.assert_compile( @@ -3387,7 +3387,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): select([func.sum(b_table.c.value)]) .where(b_table.c.a_id == a_table.c.id) .correlate(a_table) - .as_scalar() + .scalar_subquery() ) # test a different unary operator self.assert_compile( @@ -4689,7 +4689,7 @@ class SubqueryTest(fixtures.MappedTest): tag_score = tag_score.label(labelname) user_score = user_score.label(labelname) else: - user_score = user_score.as_scalar() + user_score = user_score.scalar_subquery() mapper( Tag, @@ -4798,40 +4798,28 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): ) def test_labeled_on_date_noalias(self): - self._do_test("label", True, False) + self._do_test(True, True, False) def test_scalar_on_date_noalias(self): - self._do_test("scalar", True, False) - - def test_plain_on_date_noalias(self): - self._do_test("none", True, False) + self._do_test(False, True, False) def test_labeled_on_limitid_noalias(self): - self._do_test("label", False, False) + self._do_test(True, False, False) def test_scalar_on_limitid_noalias(self): - self._do_test("scalar", False, False) - - def test_plain_on_limitid_noalias(self): - self._do_test("none", False, False) + self._do_test(False, False, False) def test_labeled_on_date_alias(self): - self._do_test("label", True, True) + self._do_test(True, True, True) def test_scalar_on_date_alias(self): - self._do_test("scalar", True, True) - - def test_plain_on_date_alias(self): - self._do_test("none", True, True) + self._do_test(False, True, True) def test_labeled_on_limitid_alias(self): - self._do_test("label", False, True) + self._do_test(True, False, True) def test_scalar_on_limitid_alias(self): - self._do_test("scalar", False, True) - - def test_plain_on_limitid_alias(self): - self._do_test("none", False, True) + self._do_test(False, False, True) def _do_test(self, labeled, ondate, aliasstuff): stuff, users = self.tables.stuff, self.tables.users @@ -4843,7 +4831,6 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): pass mapper(Stuff, stuff) - if aliasstuff: salias = stuff.alias() else: @@ -4879,11 +4866,11 @@ class CorrelatedSubqueryTest(fixtures.MappedTest): else: operator = operators.eq - if labeled == "label": + if labeled: stuff_view = stuff_view.label("foo") operator = operators.eq - elif labeled == "scalar": - stuff_view = stuff_view.as_scalar() + else: + stuff_view = stuff_view.scalar_subquery() if ondate: mapper( diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 1cccfff268..3cec10e681 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -161,79 +161,79 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL): "WHERE addresses.user_id = users.id) AS anon_1 FROM users" ) - def test_as_scalar_select_auto_correlate(self): + def test_scalar_subquery_select_auto_correlate(self): addresses, users = self.tables.addresses, self.tables.users query = select( [func.count(addresses.c.id)], addresses.c.user_id == users.c.id - ).as_scalar() + ).scalar_subquery() query = select([users.c.name.label("users_name"), query]) self.assert_compile( query, self.query_correlated, dialect=default.DefaultDialect() ) - def test_as_scalar_select_explicit_correlate(self): + def test_scalar_subquery_select_explicit_correlate(self): addresses, users = self.tables.addresses, self.tables.users query = ( select( [func.count(addresses.c.id)], addresses.c.user_id == users.c.id ) .correlate(users) - .as_scalar() + .scalar_subquery() ) query = select([users.c.name.label("users_name"), query]) self.assert_compile( query, self.query_correlated, dialect=default.DefaultDialect() ) - def test_as_scalar_select_correlate_off(self): + def test_scalar_subquery_select_correlate_off(self): addresses, users = self.tables.addresses, self.tables.users query = ( select( [func.count(addresses.c.id)], addresses.c.user_id == users.c.id ) .correlate(None) - .as_scalar() + .scalar_subquery() ) query = select([users.c.name.label("users_name"), query]) self.assert_compile( query, self.query_not_correlated, dialect=default.DefaultDialect() ) - def test_as_scalar_query_auto_correlate(self): + def test_scalar_subquery_query_auto_correlate(self): sess = create_session() Address, User = self.classes.Address, self.classes.User query = ( sess.query(func.count(Address.id)) .filter(Address.user_id == User.id) - .as_scalar() + .scalar_subquery() ) query = sess.query(User.name, query) self.assert_compile( query, self.query_correlated, dialect=default.DefaultDialect() ) - def test_as_scalar_query_explicit_correlate(self): + def test_scalar_subquery_query_explicit_correlate(self): sess = create_session() Address, User = self.classes.Address, self.classes.User query = ( sess.query(func.count(Address.id)) .filter(Address.user_id == User.id) .correlate(self.tables.users) - .as_scalar() + .scalar_subquery() ) query = sess.query(User.name, query) self.assert_compile( query, self.query_correlated, dialect=default.DefaultDialect() ) - def test_as_scalar_query_correlate_off(self): + def test_scalar_subquery_query_correlate_off(self): sess = create_session() Address, User = self.classes.Address, self.classes.User query = ( sess.query(func.count(Address.id)) .filter(Address.user_id == User.id) .correlate(None) - .as_scalar() + .scalar_subquery() ) query = sess.query(User.name, query) self.assert_compile( @@ -3243,7 +3243,7 @@ class ExternalColumnsTest(QueryTest): users.c.id == addresses.c.user_id, ) .correlate(users) - .as_scalar() + .scalar_subquery() ), }, ) @@ -3398,7 +3398,9 @@ class ExternalColumnsTest(QueryTest): select( [func.count(addresses.c.id)], users.c.id == addresses.c.user_id, - ).correlate(users) + ) + .correlate(users) + .scalar_subquery() ), }, ) diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 78680701e0..8d1e82fb54 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -1290,7 +1290,7 @@ class CorrelatedTest(fixtures.MappedTest): Stuff, primaryjoin=sa.and_( user_t.c.id == stuff.c.user_id, - stuff.c.id == (stuff_view.as_scalar()), + stuff.c.id == (stuff_view.scalar_subquery()), ), ) }, diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 93cf19faea..fa1f1fdf88 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -764,7 +764,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): expr = User.name + "name" expr2 = sa.select([User.name, users.c.id]) m.add_property("x", column_property(expr)) - m.add_property("y", column_property(expr2)) + m.add_property("y", column_property(expr2.scalar_subquery())) assert User.x.property.columns[0] is not expr assert User.x.property.columns[0].element.left is users.c.name diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 86b304be9d..27176d6fb9 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -467,7 +467,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): select([func.count(Address.id)]) .where(User.id == Address.user_id) .correlate(User) - .as_scalar(), + .scalar_subquery(), ] ), "SELECT users.name, addresses.id, " @@ -489,7 +489,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): select([func.count(Address.id)]) .where(uu.id == Address.user_id) .correlate(uu) - .as_scalar(), + .scalar_subquery(), ] ), # for a long time, "uu.id = address.user_id" was reversed; @@ -1072,7 +1072,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.ArgumentError, - "Object .*User.* is not legal as a SQL literal value", + "SQL expression element expected, got .*User", distinct, User, ) @@ -1080,7 +1080,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): ua = aliased(User) assert_raises_message( sa_exc.ArgumentError, - "Object .*User.* is not legal as a SQL literal value", + "SQL expression element expected, got .*User", distinct, ua, ) @@ -1088,21 +1088,21 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): s = Session() assert_raises_message( sa_exc.ArgumentError, - "Object .*User.* is not legal as a SQL literal value", + "SQL expression element or literal value expected, got .*User", lambda: s.query(User).filter(User.name == User), ) u1 = User() assert_raises_message( sa_exc.ArgumentError, - "Object .*User.* is not legal as a SQL literal value", + "SQL expression element expected, got .*User", distinct, u1, ) assert_raises_message( sa_exc.ArgumentError, - "Object .*User.* is not legal as a SQL literal value", + "SQL expression element or literal value expected, got .*User", lambda: s.query(User).filter(User.name == u1), ) @@ -1757,16 +1757,17 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): session = create_session() - q = session.query(User.id).filter(User.id == 7) + q = session.query(User.id).filter(User.id == 7).scalar_subquery() q = session.query(Address).filter(Address.user_id == q) + assert isinstance(q._criterion.right, expression.ColumnElement) self.assert_compile( q, "SELECT addresses.id AS addresses_id, addresses.user_id " "AS addresses_user_id, addresses.email_address AS " "addresses_email_address FROM addresses WHERE " - "addresses.user_id = (SELECT users.id AS users_id " + "addresses.user_id = (SELECT users.id " "FROM users WHERE users.id = :id_1)", ) @@ -1842,12 +1843,12 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): "WHERE users.id = :id_1) AS foo", ) - def test_as_scalar(self): + def test_scalar_subquery(self): User = self.classes.User session = create_session() - q = session.query(User.id).filter(User.id == 7).as_scalar() + q = session.query(User.id).filter(User.id == 7).scalar_subquery() self.assert_compile( session.query(User).filter(User.id.in_(q)), @@ -1866,7 +1867,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL): session.query(User.id) .filter(User.id == bindparam("foo")) .params(foo=7) - .subquery() + .scalar_subquery() ) q = session.query(User).filter(User.id.in_(q)) @@ -2081,6 +2082,8 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL): ) if label: stmt = stmt.label("email_ad") + else: + stmt = stmt.scalar_subquery() mapper( User, @@ -5491,7 +5494,7 @@ class SessionBindTest(QueryTest): column_property( select([func.sum(Address.id)]) .where(Address.user_id == User.id) - .as_scalar() + .scalar_subquery() ), ) session = Session() diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 5e6ac53fe8..259d5ea9c2 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -470,7 +470,7 @@ class _JoinFixtures(object): self.left, self.right, primaryjoin=self.left.c.id == func.foo(self.right.c.lid), - consider_as_foreign_keys=[self.right.c.lid], + consider_as_foreign_keys={self.right.c.lid}, **kw ) @@ -480,10 +480,10 @@ class _JoinFixtures(object): self.composite_multi_ref, self.composite_target, self.composite_multi_ref, - consider_as_foreign_keys=[ + consider_as_foreign_keys={ self.composite_multi_ref.c.uid2, self.composite_multi_ref.c.oid, - ], + }, **kw ) @@ -1099,10 +1099,10 @@ class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL): self.m2mleft, self.m2mright, secondary=self.m2msecondary_ambig_fks, - consider_as_foreign_keys=[ + consider_as_foreign_keys={ self.m2msecondary_ambig_fks.c.lid1, self.m2msecondary_ambig_fks.c.rid1, - ], + }, ) def test_determine_join_w_fks_ambig_m2m(self): diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 652b8bacd6..494ec2b0dc 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -2288,9 +2288,8 @@ class JoinConditionErrorTest(fixtures.TestBase): assert_raises_message( sa.exc.ArgumentError, - "Column-based expression object expected " - "for argument '%s'; got: '%s', type %r" - % (argname, arg[0], type(arg[0])), + "Column expression expected " + "for argument '%s'; got '%s'" % (argname, arg[0]), configure_mappers, ) diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 9e3f8074fc..217a4f77aa 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -326,13 +326,15 @@ class UpdateDeleteTest(fixtures.MappedTest): assert_raises( exc.InvalidRequestError, sess.query(User) - .filter(User.name == select([func.max(User.name)])) + .filter( + User.name == select([func.max(User.name)]).scalar_subquery() + ) .delete, synchronize_session="evaluate", ) sess.query(User).filter( - User.name == select([func.max(User.name)]) + User.name == select([func.max(User.name)]).scalar_subquery() ).delete(synchronize_session="fetch") assert john not in sess @@ -969,7 +971,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): subq = ( s.query(func.max(Document.title).label("title")) .group_by(Document.user_id) - .subquery() + .scalar_subquery() ) s.query(Document).filter(Document.title.in_(subq)).update( @@ -999,7 +1001,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): subq = ( s.query(func.max(Document.title).label("title")) .group_by(Document.user_id) - .subquery() + .scalar_subquery() ) # this would work with Firebird if you do literal_column('1') diff --git a/test/profiles.txt b/test/profiles.txt index 5c51d623f6..d12b3c4ae0 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -1,15 +1,15 @@ # /home/classic/dev/sqlalchemy/test/profiles.txt # This file is written out on a per-environment basis. -# For each test in aaa_profiling, the corresponding function and +# For each test in aaa_profiling, the corresponding function and # environment is located within this file. If it doesn't exist, # the test is skipped. -# If a callcount does exist, it is compared to what we received. +# If a callcount does exist, it is compared to what we received. # assertions are raised if the counts do not match. -# -# To add a new callcount test, apply the function_call_count -# decorator and re-run the tests using the --write-profiles +# +# To add a new callcount test, apply the function_call_count +# decorator and re-run the tests using the --write-profiles # option - this file will be rewritten including the new count. -# +# # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert @@ -803,14 +803,14 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_unicode 3.7_sqlite_pysqlite # TEST: test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6010,303,3905,12590,1177,2109,2565 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6054,303,4025,13894,1292,2122,2796 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 5742,282,3865,12494,1163,2042,2604 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 5808,282,3993,13758,1273,2061,2816 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6401,324,3953,12769,1200,2189,2624 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6445,324,4073,14073,1315,2202,2855 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6116,303,3913,12679,1186,2122,2667 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6165,303,4041,13945,1295,2141,2883 # TEST: test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation -test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6627,413,6961,18335,1191,2723 -test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6719,418,7081,19404,1297,2758 -test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6601,403,7149,18918,1178,2790 -test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6696,408,7285,20029,1280,2831 +test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6957,409,7143,18720,1214,2829 +test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 7053,414,7263,19789,1320,2864 +test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6946,400,7331,19307,1201,2895 +test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7046,405,7467,20418,1302,2935 diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 7bb612cfa4..491ff42bc9 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -131,7 +131,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): def test_literal_interpretation_ambiguous(self): assert_raises_message( exc.ArgumentError, - r"Ambiguous literal: 'x'. Use the 'text\(\)' function", + r"Column expression expected, got 'x'", case, [("x", "y")], ) @@ -139,7 +139,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): def test_literal_interpretation_ambiguous_tuple(self): assert_raises_message( exc.ArgumentError, - r"Ambiguous literal: \('x', 'y'\). Use the 'text\(\)' function", + r"Column expression expected, got \('x', 'y'\)", case, [(("x", "y"), "z")], ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 3608851ed6..f2feea7572 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -243,8 +243,8 @@ class CompareAndCopyTest(fixtures.TestBase): FromGrouping(table_a.alias("b")), ), lambda: ( - select([table_a.c.a]).as_scalar(), - select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(), + select([table_a.c.a]).scalar_subquery(), + select([table_a.c.a]).where(table_a.c.b == 5).scalar_subquery(), ), lambda: ( exists().where(table_a.c.a == 5), @@ -291,6 +291,7 @@ class CompareAndCopyTest(fixtures.TestBase): and "__init__" in cls.__dict__ and not issubclass(cls, (Annotated)) and "orm" not in cls.__module__ + and "compiler" not in cls.__module__ and "crud" not in cls.__module__ and "dialects" not in cls.__module__ # TODO: dialects? ).difference({ColumnElement, UnaryExpression}) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 9d6e17a1d2..e012c2713e 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -71,7 +71,6 @@ from sqlalchemy.sql import column from sqlalchemy.sql import compiler from sqlalchemy.sql import label from sqlalchemy.sql import table -from sqlalchemy.sql.expression import _literal_as_text from sqlalchemy.sql.expression import ClauseList from sqlalchemy.sql.expression import HasPrefixes from sqlalchemy.testing import assert_raises @@ -165,7 +164,8 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "columns; use this object directly within a " "column-level expression.", lambda: hasattr( - select([table1.c.myid]).as_scalar().self_group(), "columns" + select([table1.c.myid]).scalar_subquery().self_group(), + "columns", ), ) assert_raises_message( @@ -174,14 +174,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "columns; use this object directly within a " "column-level expression.", lambda: hasattr( - select([table1.c.myid]).as_scalar(), "columns" + select([table1.c.myid]).scalar_subquery(), "columns" ), ) else: assert not hasattr( - select([table1.c.myid]).as_scalar().self_group(), "columns" + select([table1.c.myid]).scalar_subquery().self_group(), + "columns", + ) + assert not hasattr( + select([table1.c.myid]).scalar_subquery(), "columns" ) - assert not hasattr(select([table1.c.myid]).as_scalar(), "columns") def test_prefix_constructor(self): class Pref(HasPrefixes): @@ -327,7 +330,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_select_precol_compile_ordering(self): - s1 = select([column("x")]).select_from(text("a")).limit(5).as_scalar() + s1 = ( + select([column("x")]) + .select_from(text("a")) + .limit(5) + .scalar_subquery() + ) s2 = select([s1]).limit(10) class MyCompiler(compiler.SQLCompiler): @@ -631,7 +639,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( - exists(s.as_scalar()), + exists(s.scalar_subquery()), "EXISTS (SELECT mytable.myid FROM mytable " "WHERE mytable.myid = :myid_1)", ) @@ -754,22 +762,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "addresses, users WHERE addresses.user_id = " "users.user_id) AS s", ) - self.assert_compile( - table1.select( - table1.c.myid - == select([table1.c.myid], table1.c.name == "jack") - ), - "SELECT mytable.myid, mytable.name, " - "mytable.description FROM mytable WHERE " - "mytable.myid = (SELECT mytable.myid FROM " - "mytable WHERE mytable.name = :name_1)", - ) self.assert_compile( table1.select( table1.c.myid == select( [table2.c.otherid], table1.c.name == table2.c.othername - ) + ).scalar_subquery() ), "SELECT mytable.myid, mytable.name, " "mytable.description FROM mytable WHERE " @@ -822,7 +820,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): order_by=[ select( [table2.c.otherid], table1.c.myid == table2.c.otherid - ) + ).scalar_subquery() ] ), "SELECT mytable.myid, mytable.name, " @@ -831,42 +829,22 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "myothertable WHERE mytable.myid = " "myothertable.otherid)", ) - self.assert_compile( - table1.select( - order_by=[ - desc( - select( - [table2.c.otherid], - table1.c.myid == table2.c.otherid, - ) - ) - ] - ), - "SELECT mytable.myid, mytable.name, " - "mytable.description FROM mytable ORDER BY " - "(SELECT myothertable.otherid FROM " - "myothertable WHERE mytable.myid = " - "myothertable.otherid) DESC", - ) def test_scalar_select(self): - assert_raises_message( - exc.InvalidRequestError, - r"Select objects don't have a type\. Call as_scalar\(\) " - r"on this Select object to return a 'scalar' " - r"version of this Select\.", - func.coalesce, - select([table1.c.myid]), + + self.assert_compile( + func.coalesce(select([table1.c.myid]).scalar_subquery()), + "coalesce((SELECT mytable.myid FROM mytable))", ) - s = select([table1.c.myid], correlate=False).as_scalar() + s = select([table1.c.myid], correlate=False).scalar_subquery() self.assert_compile( select([table1, s]), "SELECT mytable.myid, mytable.name, " "mytable.description, (SELECT mytable.myid " "FROM mytable) AS anon_1 FROM mytable", ) - s = select([table1.c.myid]).as_scalar() + s = select([table1.c.myid]).scalar_subquery() self.assert_compile( select([table2, s]), "SELECT myothertable.otherid, " @@ -874,7 +852,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "mytable.myid FROM mytable) AS anon_1 FROM " "myothertable", ) - s = select([table1.c.myid]).correlate(None).as_scalar() + s = select([table1.c.myid]).correlate(None).scalar_subquery() self.assert_compile( select([table1, s]), "SELECT mytable.myid, mytable.name, " @@ -882,17 +860,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "FROM mytable) AS anon_1 FROM mytable", ) - s = select([table1.c.myid]).as_scalar() + s = select([table1.c.myid]).scalar_subquery() s2 = s.where(table1.c.myid == 5) self.assert_compile( s2, "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)", ) self.assert_compile(s, "(SELECT mytable.myid FROM mytable)") - # test that aliases use as_scalar() when used in an explicitly + # test that aliases use scalar_subquery() when used in an explicitly # scalar context - s = select([table1.c.myid]).alias() + s = select([table1.c.myid]).scalar_subquery() self.assert_compile( select([table1.c.myid]).where(table1.c.myid == s), "SELECT mytable.myid FROM mytable WHERE " @@ -902,10 +880,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([table1.c.myid]).where(s > table1.c.myid), "SELECT mytable.myid FROM mytable WHERE " - "mytable.myid < (SELECT mytable.myid FROM " - "mytable)", + "(SELECT mytable.myid FROM mytable) > mytable.myid", ) - s = select([table1.c.myid]).as_scalar() + s = select([table1.c.myid]).scalar_subquery() self.assert_compile( select([table2, s]), "SELECT myothertable.otherid, " @@ -922,7 +899,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "- :param_1 AS anon_1", ) self.assert_compile( - select([select([table1.c.name]).as_scalar() + literal("x")]), + select([select([table1.c.name]).scalar_subquery() + literal("x")]), "SELECT (SELECT mytable.name FROM mytable) " "|| :param_1 AS anon_1", ) @@ -939,7 +916,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): # scalar selects should not have any attributes on their 'c' or # 'columns' attribute - s = select([table1.c.myid]).as_scalar() + s = select([table1.c.myid]).scalar_subquery() try: s.c.foo except exc.InvalidRequestError as err: @@ -965,12 +942,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): qlat = ( select([zips.c.latitude], zips.c.zipcode == zipcode) .correlate(None) - .as_scalar() + .scalar_subquery() ) qlng = ( select([zips.c.longitude], zips.c.zipcode == zipcode) .correlate(None) - .as_scalar() + .scalar_subquery() ) q = select( @@ -999,10 +976,10 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): zalias = zips.alias("main_zip") qlat = select( [zips.c.latitude], zips.c.zipcode == zalias.c.zipcode - ).as_scalar() + ).scalar_subquery() qlng = select( [zips.c.longitude], zips.c.zipcode == zalias.c.zipcode - ).as_scalar() + ).scalar_subquery() q = select( [ places.c.id, @@ -1025,7 +1002,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): ) a1 = table2.alias("t2alias") - s1 = select([a1.c.otherid], table1.c.myid == a1.c.otherid).as_scalar() + s1 = select( + [a1.c.otherid], table1.c.myid == a1.c.otherid + ).scalar_subquery() j1 = table1.join(table2, table1.c.myid == table2.c.otherid) s2 = select([table1, s1], from_obj=j1) self.assert_compile( @@ -2337,7 +2316,11 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): assert s3.compile().params == {"myid": 9, "myotherid": 7} # test using same 'unique' param object twice in one compile - s = select([table1.c.myid]).where(table1.c.myid == 12).as_scalar() + s = ( + select([table1.c.myid]) + .where(table1.c.myid == 12) + .scalar_subquery() + ) s2 = select([table1, s], table1.c.myid == s) self.assert_compile( s2, @@ -2884,7 +2867,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): lambda: sel1.c, ) - # calling label or as_scalar doesn't compile + # calling label or scalar_subquery doesn't compile # anything. sel2 = select([func.substr(my_str, 2, 3)]).label("my_substr") @@ -2895,7 +2878,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): dialect=default.DefaultDialect(), ) - sel3 = select([my_str]).as_scalar() + sel3 = select([my_str]).scalar_subquery() assert_raises_message( exc.CompileError, "Cannot compile Column object until its 'name' is assigned.", @@ -3919,13 +3902,13 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_semiauto_where(self): t1, t2, s1 = self._fixture() self._assert_where_correlated( - select([t2]).where(t2.c.a == s1.correlate(t2)) + select([t2]).where(t2.c.a == s1.correlate(t2).scalar_subquery()) ) def test_correlate_semiauto_column(self): t1, t2, s1 = self._fixture() self._assert_column_correlated( - select([t2, s1.correlate(t2).as_scalar()]) + select([t2, s1.correlate(t2).scalar_subquery()]) ) def test_correlate_semiauto_from(self): @@ -3935,31 +3918,35 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_semiauto_having(self): t1, t2, s1 = self._fixture() self._assert_having_correlated( - select([t2]).having(t2.c.a == s1.correlate(t2)) + select([t2]).having(t2.c.a == s1.correlate(t2).scalar_subquery()) ) def test_correlate_except_inclusion_where(self): t1, t2, s1 = self._fixture() self._assert_where_correlated( - select([t2]).where(t2.c.a == s1.correlate_except(t1)) + select([t2]).where( + t2.c.a == s1.correlate_except(t1).scalar_subquery() + ) ) def test_correlate_except_exclusion_where(self): t1, t2, s1 = self._fixture() self._assert_where_uncorrelated( - select([t2]).where(t2.c.a == s1.correlate_except(t2)) + select([t2]).where( + t2.c.a == s1.correlate_except(t2).scalar_subquery() + ) ) def test_correlate_except_inclusion_column(self): t1, t2, s1 = self._fixture() self._assert_column_correlated( - select([t2, s1.correlate_except(t1).as_scalar()]) + select([t2, s1.correlate_except(t1).scalar_subquery()]) ) def test_correlate_except_exclusion_column(self): t1, t2, s1 = self._fixture() self._assert_column_uncorrelated( - select([t2, s1.correlate_except(t2).as_scalar()]) + select([t2, s1.correlate_except(t2).scalar_subquery()]) ) def test_correlate_except_inclusion_from(self): @@ -3977,22 +3964,28 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_except_none(self): t1, t2, s1 = self._fixture() self._assert_where_all_correlated( - select([t1, t2]).where(t2.c.a == s1.correlate_except(None)) + select([t1, t2]).where( + t2.c.a == s1.correlate_except(None).scalar_subquery() + ) ) def test_correlate_except_having(self): t1, t2, s1 = self._fixture() self._assert_having_correlated( - select([t2]).having(t2.c.a == s1.correlate_except(t1)) + select([t2]).having( + t2.c.a == s1.correlate_except(t1).scalar_subquery() + ) ) def test_correlate_auto_where(self): t1, t2, s1 = self._fixture() - self._assert_where_correlated(select([t2]).where(t2.c.a == s1)) + self._assert_where_correlated( + select([t2]).where(t2.c.a == s1.scalar_subquery()) + ) def test_correlate_auto_column(self): t1, t2, s1 = self._fixture() - self._assert_column_correlated(select([t2, s1.as_scalar()])) + self._assert_column_correlated(select([t2, s1.scalar_subquery()])) def test_correlate_auto_from(self): t1, t2, s1 = self._fixture() @@ -4000,18 +3993,20 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_auto_having(self): t1, t2, s1 = self._fixture() - self._assert_having_correlated(select([t2]).having(t2.c.a == s1)) + self._assert_having_correlated( + select([t2]).having(t2.c.a == s1.scalar_subquery()) + ) def test_correlate_disabled_where(self): t1, t2, s1 = self._fixture() self._assert_where_uncorrelated( - select([t2]).where(t2.c.a == s1.correlate(None)) + select([t2]).where(t2.c.a == s1.correlate(None).scalar_subquery()) ) def test_correlate_disabled_column(self): t1, t2, s1 = self._fixture() self._assert_column_uncorrelated( - select([t2, s1.correlate(None).as_scalar()]) + select([t2, s1.correlate(None).scalar_subquery()]) ) def test_correlate_disabled_from(self): @@ -4023,19 +4018,21 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_disabled_having(self): t1, t2, s1 = self._fixture() self._assert_having_uncorrelated( - select([t2]).having(t2.c.a == s1.correlate(None)) + select([t2]).having(t2.c.a == s1.correlate(None).scalar_subquery()) ) def test_correlate_all_where(self): t1, t2, s1 = self._fixture() self._assert_where_all_correlated( - select([t1, t2]).where(t2.c.a == s1.correlate(t1, t2)) + select([t1, t2]).where( + t2.c.a == s1.correlate(t1, t2).scalar_subquery() + ) ) def test_correlate_all_column(self): t1, t2, s1 = self._fixture() self._assert_column_all_correlated( - select([t1, t2, s1.correlate(t1, t2).as_scalar()]) + select([t1, t2, s1.correlate(t1, t2).scalar_subquery()]) ) def test_correlate_all_from(self): @@ -4049,7 +4046,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): assert_raises_message( exc.InvalidRequestError, "returned no FROM clauses due to auto-correlation", - select([t1, t2]).where(t2.c.a == s1).compile, + select([t1, t2]).where(t2.c.a == s1.scalar_subquery()).compile, ) def test_correlate_from_all_ok(self): @@ -4063,7 +4060,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): def test_correlate_auto_where_singlefrom(self): t1, t2, s1 = self._fixture() s = select([t1.c.a]) - s2 = select([t1]).where(t1.c.a == s) + s2 = select([t1]).where(t1.c.a == s.scalar_subquery()) self.assert_compile( s2, "SELECT t1.a FROM t1 WHERE t1.a = " "(SELECT t1.a FROM t1)" ) @@ -4073,7 +4070,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): s = select([t1.c.a]) - s2 = select([t1]).where(t1.c.a == s.correlate(t1)) + s2 = select([t1]).where(t1.c.a == s.correlate(t1).scalar_subquery()) self._assert_where_single_full_correlated(s2) def test_correlate_except_semiauto_where_singlefrom(self): @@ -4081,7 +4078,9 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): s = select([t1.c.a]) - s2 = select([t1]).where(t1.c.a == s.correlate_except(t2)) + s2 = select([t1]).where( + t1.c.a == s.correlate_except(t2).scalar_subquery() + ) self._assert_where_single_full_correlated(s2) def test_correlate_alone_noeffect(self): @@ -4098,7 +4097,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): s = select([t2.c.b]).where(t1.c.a == t2.c.a) s = s.correlate_except(t2).alias("s") - s2 = select([func.foo(s.c.b)]).as_scalar() + s2 = select([func.foo(s.c.b)]).scalar_subquery() s3 = select([t1], order_by=s2) self.assert_compile( @@ -4155,8 +4154,8 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL): t3 = table("t3", column("z")) s = select([t1.c.x]).where(t1.c.x == t2.c.y) - s2 = select([t3.c.z]).where(t3.c.z == s.as_scalar()) - s3 = select([t1]).where(t1.c.x == s2.as_scalar()) + s2 = select([t3.c.z]).where(t3.c.z == s.scalar_subquery()) + s3 = select([t1]).where(t1.c.x == s2.scalar_subquery()) self.assert_compile( s3, @@ -4214,15 +4213,6 @@ class CoercionTest(fixtures.TestBase, AssertsCompiledSQL): dialect=default.DefaultDialect(supports_native_boolean=False), ) - def test_null_constant(self): - self.assert_compile(_literal_as_text(None), "NULL") - - def test_false_constant(self): - self.assert_compile(_literal_as_text(False), "false") - - def test_true_constant(self): - self.assert_compile(_literal_as_text(True), "true") - def test_val_and_false(self): t = self._fixture() self.assert_compile(and_(t.c.id == 1, False), "false") diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 7008bc1cca..ac46e7d5d9 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -47,7 +47,9 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): select([regional_sales.c.region]) .where( regional_sales.c.total_sales - > select([func.sum(regional_sales.c.total_sales) / 10]) + > select( + [func.sum(regional_sales.c.total_sales) / 10] + ).scalar_subquery() ) .cte("top_regions") ) diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 76ef38e1f1..ed7af2572e 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -657,8 +657,8 @@ class DefaultTest(fixtures.TestBase): ): assert_raises_message( sa.exc.ArgumentError, - "SQL expression object expected, got object of type " - "<.* 'list'> instead", + r"SQL expression for WHERE/HAVING role expected, " + r"got \[(?:Sequence|ColumnDefault|DefaultClause)\('y'.*\)\]", t.select, [const], ) @@ -913,7 +913,7 @@ class PKDefaultTest(fixtures.TablesTest): "id", Integer, primary_key=True, - default=sa.select([func.max(t2.c.nextid)]).as_scalar(), + default=sa.select([func.max(t2.c.nextid)]).scalar_subquery(), ), Column("data", String(30)), ) @@ -1740,7 +1740,7 @@ class SpecialTypePKTest(fixtures.TestBase): self._run_test(server_default="1", autoincrement=False) def test_clause(self): - stmt = select([cast("INT_1", type_=self.MyInteger)]).as_scalar() + stmt = select([cast("INT_1", type_=self.MyInteger)]).scalar_subquery() self._run_test(default=stmt) @testing.requires.returning diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py index f572a510ca..1f4c49c562 100644 --- a/test/sql/test_delete.py +++ b/test/sql/test_delete.py @@ -109,13 +109,13 @@ class DeleteTest(_DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL): stmt, "DELETE FROM mytable AS t1 WHERE t1.myid = :myid_1" ) - def test_correlated(self): + def test_non_correlated_select(self): table1, table2 = self.tables.mytable, self.tables.myothertable # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) self.assert_compile( - delete(table1, table1.c.name == s), + delete(table1, table1.c.name == s.scalar_subquery()), "DELETE FROM mytable " "WHERE mytable.name = (" "SELECT myothertable.othername " @@ -124,10 +124,13 @@ class DeleteTest(_DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL): ")", ) + def test_correlated_select(self): + table1, table2 = self.tables.mytable, self.tables.myothertable + # test one that is actually correlated... s = select([table2.c.othername], table2.c.otherid == table1.c.myid) self.assert_compile( - table1.delete(table1.c.name == s), + table1.delete(table1.c.name == s.scalar_subquery()), "DELETE FROM mytable " "WHERE mytable.name = (" "SELECT myothertable.othername " diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index 7990cd56c6..8e8591aecc 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -6,6 +6,7 @@ from sqlalchemy import column from sqlalchemy import create_engine from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import select @@ -17,6 +18,8 @@ from sqlalchemy import text from sqlalchemy import util from sqlalchemy.engine import default from sqlalchemy.schema import DDL +from sqlalchemy.sql import coercions +from sqlalchemy.sql import roles from sqlalchemy.sql import util as sql_util from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -24,6 +27,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock @@ -372,6 +376,117 @@ class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL): eq_(s._for_update_arg.nowait, True) +class SubqueryCoercionsTest(fixtures.TestBase, AssertsCompiledSQL): + def test_column_roles(self): + stmt = select([table1.c.myid]) + + for role in [ + roles.WhereHavingRole, + roles.ExpressionElementRole, + roles.ByOfRole, + roles.OrderByRole, + # roles.LabeledColumnExprRole + ]: + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + coerced = coercions.expect(role, stmt) + is_true(coerced.compare(stmt.scalar_subquery())) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + coerced = coercions.expect(role, stmt.alias()) + is_true(coerced.compare(stmt.scalar_subquery())) + + def test_labeled_role(self): + stmt = select([table1.c.myid]) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + coerced = coercions.expect(roles.LabeledColumnExprRole, stmt) + is_true(coerced.compare(stmt.scalar_subquery().label(None))) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + coerced = coercions.expect( + roles.LabeledColumnExprRole, stmt.alias() + ) + is_true(coerced.compare(stmt.scalar_subquery().label(None))) + + def test_scalar_select(self): + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + self.assert_compile( + func.coalesce(select([table1.c.myid])), + "coalesce((SELECT mytable.myid FROM mytable))", + ) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + s = select([table1.c.myid]).alias() + self.assert_compile( + select([table1.c.myid]).where(table1.c.myid == s), + "SELECT mytable.myid FROM mytable WHERE " + "mytable.myid = (SELECT mytable.myid FROM " + "mytable)", + ) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + self.assert_compile( + select([table1.c.myid]).where(s > table1.c.myid), + "SELECT mytable.myid FROM mytable WHERE " + "mytable.myid < (SELECT mytable.myid FROM " + "mytable)", + ) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + s = select([table1.c.myid]).alias() + self.assert_compile( + select([table1.c.myid]).where(table1.c.myid == s), + "SELECT mytable.myid FROM mytable WHERE " + "mytable.myid = (SELECT mytable.myid FROM " + "mytable)", + ) + + with testing.expect_deprecated( + "coercing SELECT object to scalar " + "subquery in a column-expression context is deprecated" + ): + self.assert_compile( + select([table1.c.myid]).where(s > table1.c.myid), + "SELECT mytable.myid FROM mytable WHERE " + "mytable.myid < (SELECT mytable.myid FROM " + "mytable)", + ) + + def test_as_scalar(self): + with testing.expect_deprecated( + r"The SelectBase.as_scalar\(\) method is deprecated and " + "will be removed in a future release." + ): + stmt = select([table1.c.myid]).as_scalar() + + is_true(stmt.compare(select([table1.c.myid]).scalar_subquery())) + + class TextTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @@ -425,3 +540,11 @@ class TextTest(fixtures.TestBase, AssertsCompiledSQL): "The text.autocommit parameter is deprecated" ): t = text("select id, name from user", autocommit=True) + + +table1 = table( + "mytable", + column("myid", Integer), + column("name", String), + column("description", String), +) diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py index 9b902e8ffd..da139d7c04 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_generative.py @@ -22,7 +22,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors -from sqlalchemy.sql.expression import _clone +from sqlalchemy.sql.elements import _clone from sqlalchemy.sql.expression import _from_objects from sqlalchemy.sql.visitors import ClauseVisitor from sqlalchemy.sql.visitors import cloned_traverse @@ -306,7 +306,7 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): def test_subquery(self): a, b, c = column("a"), column("b"), column("c") - subq = select([c]).where(c == a).as_scalar() + subq = select([c]).where(c == a).scalar_subquery() expr = and_(a == b, b == subq) self._assert_traversal( expr, [(operators.eq, a, b), (operators.eq, b, subq)] @@ -706,7 +706,9 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): select.append_whereclause(t1.c.col2 == 7) self.assert_compile( - select([t2]).where(t2.c.col1 == Vis().traverse(s)), + select([t2]).where( + t2.c.col1 == Vis().traverse(s).scalar_subquery() + ), "SELECT table2.col1, table2.col2, table2.col3 " "FROM table2 WHERE table2.col1 = " "(SELECT * FROM table1 WHERE table1.col1 = table2.col1 " @@ -739,7 +741,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): t1a = t1.alias() s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1a) - s = select([t1]).where(t1.c.col1 == s) + s = select([t1]).where(t1.c.col1 == s.scalar_subquery()) self.assert_compile( s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1 " @@ -760,7 +762,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): s = select([t1]).where(t1.c.col1 == "foo").alias() s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1) - s3 = select([t1]).where(t1.c.col1 == s2) + s3 = select([t1]).where(t1.c.col1 == s2.scalar_subquery()) self.assert_compile( s3, "SELECT table1.col1, table1.col2, table1.col3 " @@ -982,7 +984,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): s = select( [literal_column("*")], from_obj=[t1alias, t2alias] - ).as_scalar() + ).scalar_subquery() assert t2alias in s._froms assert t1alias in s._froms @@ -1011,7 +1013,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): s = ( select([literal_column("*")], from_obj=[t1alias, t2alias]) .correlate(t2alias) - .as_scalar() + .scalar_subquery() ) self.assert_compile( select([literal_column("*")], t2alias.c.col1 == s), @@ -1037,7 +1039,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): s = ( select([literal_column("*")]) .where(t1.c.col1 == t2.c.col1) - .as_scalar() + .scalar_subquery() ) self.assert_compile( select([t1.c.col1, s]), @@ -1064,7 +1066,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): select([literal_column("*")]) .where(t1.c.col1 == t2.c.col1) .correlate(t1) - .as_scalar() + .scalar_subquery() ) self.assert_compile( select([t1.c.col1, s]), @@ -1102,7 +1104,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): select([t2.c.col1]) .where(t2.c.col1 == t1.c.col1) .correlate(t2) - .as_scalar() + .scalar_subquery() ) # test subquery - given only t1 and t2 in the enclosing selectable, @@ -1112,7 +1114,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): select([t2.c.col1]) .where(t2.c.col1 == t1.c.col1) .correlate_except(t1) - .as_scalar() + .scalar_subquery() ) # use both subqueries in statements @@ -1257,7 +1259,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): [literal_column("*")], t1.c.col1 == t2.c.col2, from_obj=[t1, t2], - ).correlate(t1) + ) + .correlate(t1) + .scalar_subquery() ) ), "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " @@ -1277,7 +1281,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): [literal_column("*")], t1.c.col1 == t2.c.col2, from_obj=[t1, t2], - ).correlate(t2) + ) + .correlate(t2) + .scalar_subquery() ) ), "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " @@ -1381,9 +1387,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): select([t1alias, t2alias]).where( t1alias.c.col1 == vis.traverse( - select( - ["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2] - ).correlate(t1) + select(["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2]) + .correlate(t1) + .scalar_subquery() ) ), "SELECT t1alias.col1, t1alias.col2, t1alias.col3, " @@ -1403,9 +1409,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): t2alias.select().where( t2alias.c.col2 == vis.traverse( - select( - ["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2] - ).correlate(t2) + select(["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2]) + .correlate(t2) + .scalar_subquery() ) ), "SELECT t2alias.col1, t2alias.col2, t2alias.col3 " diff --git a/test/sql/test_inspect.py b/test/sql/test_inspect.py index 0b7aa7a555..d2a2c1c484 100644 --- a/test/sql/test_inspect.py +++ b/test/sql/test_inspect.py @@ -5,6 +5,7 @@ from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import Table +from sqlalchemy.sql import ClauseElement from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -37,5 +38,13 @@ class TestCoreInspection(fixtures.TestBase): # absence of __clause_element__ as a test for "this is the clause # element" must be maintained + class Foo(ClauseElement): + pass + + assert not hasattr(Foo(), "__clause_element__") + + def test_col_now_has_a_clauseelement(self): + x = Column("foo", Integer) - assert not hasattr(x, "__clause_element__") + + assert hasattr(x, "__clause_element__") diff --git a/test/sql/test_join_rewriting.py b/test/sql/test_join_rewriting.py index e91557cd15..b9bcfc16a1 100644 --- a/test/sql/test_join_rewriting.py +++ b/test/sql/test_join_rewriting.py @@ -259,8 +259,8 @@ class _JoinRewriteTestBase(AssertsCompiledSQL): self._test(s, self._f_b1a_where_in_b2a) def test_anon_scalar_subqueries(self): - s1 = select([1]).as_scalar() - s2 = select([2]).as_scalar() + s1 = select([1]).scalar_subquery() + s2 = select([2]).scalar_subquery() s = select([s1, s2]).apply_labels() self._test(s, self._anon_scalar_subqueries) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 3d60fb60ed..3f96676096 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -31,7 +31,6 @@ from sqlalchemy import Unicode from sqlalchemy import UniqueConstraint from sqlalchemy import util from sqlalchemy.engine import default -from sqlalchemy.sql import elements from sqlalchemy.sql import naming from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -3378,7 +3377,8 @@ class ConstraintTest(fixtures.TestBase): assert_raises_message( exc.ArgumentError, - r"Element Table\('t2', .* is not a string name or column element", + r"String column name or column object for DDL constraint " + r"expected, got .*SomeClass", Index, "foo", SomeClass(), @@ -4609,7 +4609,7 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL): u1 = self._fixture( naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"} ) - ck = CheckConstraint(u1.c.data == "x", name=elements._defer_name(None)) + ck = CheckConstraint(u1.c.data == "x", name=naming._defer_name(None)) assert_raises_message( exc.InvalidRequestError, diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index c6eff6ac93..f85a601bab 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -26,6 +26,7 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import all_ from sqlalchemy.sql import any_ from sqlalchemy.sql import asc +from sqlalchemy.sql import coercions from sqlalchemy.sql import collate from sqlalchemy.sql import column from sqlalchemy.sql import compiler @@ -34,10 +35,10 @@ from sqlalchemy.sql import false from sqlalchemy.sql import literal from sqlalchemy.sql import null from sqlalchemy.sql import operators +from sqlalchemy.sql import roles from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import true -from sqlalchemy.sql.elements import _literal_as_text from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import Label from sqlalchemy.sql.expression import BinaryExpression @@ -82,13 +83,17 @@ class DefaultColumnComparatorTest(fixtures.TestBase): assert left.comparator.operate(operator, right).compare( BinaryExpression( - _literal_as_text(left), _literal_as_text(right), operator + coercions.expect(roles.WhereHavingRole, left), + coercions.expect(roles.WhereHavingRole, right), + operator, ) ) assert operator(left, right).compare( BinaryExpression( - _literal_as_text(left), _literal_as_text(right), operator + coercions.expect(roles.WhereHavingRole, left), + coercions.expect(roles.WhereHavingRole, right), + operator, ) ) @@ -227,8 +232,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase): left = column("left") foo = ClauseList() assert_raises_message( - exc.InvalidRequestError, - r"in_\(\) accepts either a list of expressions, a selectable", + exc.ArgumentError, + r"IN expression list, SELECT construct, or bound parameter " + r"object expected, got .*ClauseList", left.in_, [foo], ) @@ -237,8 +243,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase): left = column("left") right = column("right") assert_raises_message( - exc.InvalidRequestError, - r"in_\(\) accepts either a list of expressions, a selectable", + exc.ArgumentError, + r"IN expression list, SELECT construct, or bound parameter " + r"object expected, got .*ColumnClause", left.in_, right, ) @@ -253,8 +260,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase): left = column("left") right = column("right", HasGetitem) assert_raises_message( - exc.InvalidRequestError, - r"in_\(\) accepts either a list of expressions, a selectable", + exc.ArgumentError, + r"IN expression list, SELECT construct, or bound parameter " + r"object expected, got .*ColumnClause", left.in_, right, ) @@ -1680,7 +1688,7 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): select( [ self.table1.c.myid.in_( - select([self.table2.c.otherid]).as_scalar() + select([self.table2.c.otherid]).scalar_subquery() ) ] ), @@ -1738,6 +1746,29 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.table1.c.myid.in_([None]), "mytable.myid IN (NULL)" ) + def test_in_set(self): + self.assert_compile( + self.table1.c.myid.in_({1, 2, 3}), + "mytable.myid IN (:myid_1, :myid_2, :myid_3)", + ) + + def test_in_arbitrary_sequence(self): + class MySeq(object): + def __init__(self, d): + self.d = d + + def __getitem__(self, idx): + return self.d[idx] + + def __iter__(self): + return iter(self.d) + + seq = MySeq([1, 2, 3]) + self.assert_compile( + self.table1.c.myid.in_(seq), + "mytable.myid IN (:myid_1, :myid_2, :myid_3)", + ) + def test_empty_in_dynamic_1(self): self.assert_compile( self.table1.c.myid.in_([]), @@ -2073,7 +2104,9 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert not (self.table1.c.myid + 5)._is_implicitly_boolean assert not not_(column("x", Boolean))._is_implicitly_boolean assert ( - not select([self.table1.c.myid]).as_scalar()._is_implicitly_boolean + not select([self.table1.c.myid]) + .scalar_subquery() + ._is_implicitly_boolean ) assert not text("x = y")._is_implicitly_boolean assert not literal_column("x = y")._is_implicitly_boolean @@ -2869,7 +2902,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): t = self._fixture() self.assert_compile( - 5 == any_(select([t.c.data]).where(t.c.data < 10)), + 5 + == any_(select([t.c.data]).where(t.c.data < 10).scalar_subquery()), ":param_1 = ANY (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", checkparams={"data_1": 10, "param_1": 5}, @@ -2879,7 +2913,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): t = self._fixture() self.assert_compile( - 5 == select([t.c.data]).where(t.c.data < 10).as_scalar().any_(), + 5 + == select([t.c.data]) + .where(t.c.data < 10) + .scalar_subquery() + .any_(), ":param_1 = ANY (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", checkparams={"data_1": 10, "param_1": 5}, @@ -2889,7 +2927,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): t = self._fixture() self.assert_compile( - 5 == all_(select([t.c.data]).where(t.c.data < 10)), + 5 + == all_(select([t.c.data]).where(t.c.data < 10).scalar_subquery()), ":param_1 = ALL (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", checkparams={"data_1": 10, "param_1": 5}, @@ -2899,7 +2938,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): t = self._fixture() self.assert_compile( - 5 == select([t.c.data]).where(t.c.data < 10).as_scalar().all_(), + 5 + == select([t.c.data]) + .where(t.c.data < 10) + .scalar_subquery() + .all_(), ":param_1 = ALL (SELECT tab1.data " "FROM tab1 WHERE tab1.data < :data_1)", checkparams={"data_1": 10, "param_1": 5}, diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 5987c77467..6e48374ca3 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -121,7 +121,7 @@ class ResultProxyTest(fixtures.TablesTest): sel = ( select([users.c.user_id]) .where(users.c.user_name == "jack") - .as_scalar() + .scalar_subquery() ) for row in select([sel + 1, sel + 3], bind=users.bind).execute(): eq_(row["anon_1"], 8) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index aa10626c5c..ad1eb33480 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -222,7 +222,7 @@ class CompositeStatementTest(fixtures.TestBase): stmt = ( t2.insert() - .values(x=select([t1.c.x]).as_scalar()) + .values(x=select([t1.c.x]).scalar_subquery()) .returning(t2.c.x) ) diff --git a/test/sql/test_roles.py b/test/sql/test_roles.py new file mode 100644 index 0000000000..81934de241 --- /dev/null +++ b/test/sql/test_roles.py @@ -0,0 +1,218 @@ +from sqlalchemy import Column +from sqlalchemy import exc +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import select +from sqlalchemy import Table +from sqlalchemy import text +from sqlalchemy.schema import DDL +from sqlalchemy.schema import Sequence +from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql import coercions +from sqlalchemy.sql import column +from sqlalchemy.sql import false +from sqlalchemy.sql import False_ +from sqlalchemy.sql import literal +from sqlalchemy.sql import roles +from sqlalchemy.sql import true +from sqlalchemy.sql import True_ +from sqlalchemy.sql.coercions import expect +from sqlalchemy.sql.elements import _truncated_label +from sqlalchemy.sql.elements import Null +from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_instance_of +from sqlalchemy.testing import is_true + +m = MetaData() + +t = Table("t", m, Column("q", Integer)) + + +class NotAThing1(object): + pass + + +not_a_thing1 = NotAThing1() + + +class NotAThing2(ClauseElement): + pass + + +not_a_thing2 = NotAThing2() + + +class NotAThing3(object): + def __clause_element__(self): + return not_a_thing2 + + +not_a_thing3 = NotAThing3() + + +class RoleTest(fixtures.TestBase): + # TODO: the individual role tests here are incomplete. The functionality + # of each role is covered by other tests in the sql testing suite however + # ideally they'd all have direct tests here as well. + + def _test_role_neg_comparisons(self, role): + impl = coercions._impl_lookup[role] + role_name = impl.name + + assert_raises_message( + exc.ArgumentError, + r"%s expected, got .*NotAThing1" % role_name, + expect, + role, + not_a_thing1, + ) + + assert_raises_message( + exc.ArgumentError, + r"%s expected, got .*NotAThing2" % role_name, + expect, + role, + not_a_thing2, + ) + + assert_raises_message( + exc.ArgumentError, + r"%s expected, got .*NotAThing3" % role_name, + expect, + role, + not_a_thing3, + ) + + assert_raises_message( + exc.ArgumentError, + r"%s expected for argument 'foo'; got .*NotAThing3" % role_name, + expect, + role, + not_a_thing3, + argname="foo", + ) + + def test_const_expr_role(self): + t = true() + is_(expect(roles.ConstExprRole, t), t) + + f = false() + is_(expect(roles.ConstExprRole, f), f) + + is_instance_of(expect(roles.ConstExprRole, True), True_) + + is_instance_of(expect(roles.ConstExprRole, False), False_) + + is_instance_of(expect(roles.ConstExprRole, None), Null) + + def test_truncated_label_role(self): + is_instance_of( + expect(roles.TruncatedLabelRole, "foobar"), _truncated_label + ) + + def test_labeled_column_expr_role(self): + c = column("q") + is_true(expect(roles.LabeledColumnExprRole, c).compare(c)) + + is_true( + expect(roles.LabeledColumnExprRole, c.label("foo")).compare( + c.label("foo") + ) + ) + + is_true( + expect( + roles.LabeledColumnExprRole, + select([column("q")]).scalar_subquery(), + ).compare(select([column("q")]).label(None)) + ) + + is_true( + expect(roles.LabeledColumnExprRole, not_a_thing1).compare( + literal(not_a_thing1).label(None) + ) + ) + + def test_scalar_select_no_coercion(self): + # this is also tested in test/sql/test_deprecations.py; when the + # deprecation is turned to an error, those tests go away, and these + # will assert the correct exception plus informative error message. + assert_raises_message( + exc.SADeprecationWarning, + "coercing SELECT object to scalar subquery in a column-expression " + "context is deprecated", + expect, + roles.LabeledColumnExprRole, + select([column("q")]), + ) + + assert_raises_message( + exc.SADeprecationWarning, + "coercing SELECT object to scalar subquery in a column-expression " + "context is deprecated", + expect, + roles.LabeledColumnExprRole, + select([column("q")]).alias(), + ) + + def test_statement_no_text_coercion(self): + assert_raises_message( + exc.ArgumentError, + r"Textual SQL expression 'select \* from table' should be " + r"explicitly declared", + expect, + roles.StatementRole, + "select * from table", + ) + + def test_statement_text_coercion(self): + is_true( + expect( + roles.CoerceTextStatementRole, "select * from table" + ).compare(text("select * from table")) + ) + + def test_select_statement_no_text_coercion(self): + assert_raises_message( + exc.ArgumentError, + r"Textual SQL expression 'select \* from table' should be " + r"explicitly declared", + expect, + roles.SelectStatementRole, + "select * from table", + ) + + def test_statement_coercion_select(self): + is_true( + expect(roles.CoerceTextStatementRole, select([t])).compare( + select([t]) + ) + ) + + def test_statement_coercion_ddl(self): + d1 = DDL("hi") + is_(expect(roles.CoerceTextStatementRole, d1), d1) + + def test_statement_coercion_sequence(self): + s1 = Sequence("hi") + is_(expect(roles.CoerceTextStatementRole, s1), s1) + + def test_columns_clause_role(self): + is_(expect(roles.ColumnsClauseRole, t.c.q), t.c.q) + + def test_truncated_label_role_neg(self): + self._test_role_neg_comparisons(roles.TruncatedLabelRole) + + def test_where_having_role_neg(self): + self._test_role_neg_comparisons(roles.WhereHavingRole) + + def test_by_of_role_neg(self): + self._test_role_neg_comparisons(roles.ByOfRole) + + def test_const_expr_role_neg(self): + self._test_role_neg_comparisons(roles.ConstExprRole) + + def test_columns_clause_role_neg(self): + self._test_role_neg_comparisons(roles.ColumnsClauseRole) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index a4d3e1b406..f88243fc27 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -31,7 +31,6 @@ from sqlalchemy import util from sqlalchemy.sql import Alias from sqlalchemy.sql import column from sqlalchemy.sql import elements -from sqlalchemy.sql import expression from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors @@ -221,7 +220,7 @@ class SelectableTest( s3c1 = s3._clone() eq_( - expression._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]), + elements._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]), set([s1c1]), ) @@ -240,7 +239,7 @@ class SelectableTest( s3c1 = s3._clone() eq_( - expression._cloned_difference([s1c1, s2c1, s3c1], [s2c1, s1c2]), + elements._cloned_difference([s1c1, s2c1, s3c1], [s2c1, s1c2]), set([s3c1]), ) @@ -313,7 +312,7 @@ class SelectableTest( assert sel3.corresponding_column(col) is sel3.c.foo def test_with_only_generative(self): - s1 = table1.select().as_scalar() + s1 = table1.select().scalar_subquery() self.assert_compile( s1.with_only_columns([s1]), "SELECT (SELECT table1.col1, table1.col2, " @@ -365,12 +364,18 @@ class SelectableTest( criterion = a.c.col1 == table2.c.col2 self.assert_(criterion.compare(j.onclause)) + @testing.fails("not supported with rework, need a new approach") def test_alias_handles_column_context(self): # not quite a use case yet but this is expected to become # prominent w/ PostgreSQL's tuple functions stmt = select([table1.c.col1, table1.c.col2]) a = stmt.alias("a") + + # TODO: this case is crazy, sending SELECT or FROMCLAUSE has to + # be figured out - is it a scalar row query? what kinds of + # statements go into functions in PG. seems likely select statment, + # but not alias, subquery or other FROM object self.assert_compile( select([func.foo(a)]), "SELECT foo(SELECT table1.col1, table1.col2 FROM table1) " @@ -652,7 +657,7 @@ class SelectableTest( self.assert_(criterion.compare(j.onclause)) def test_scalar_cloned_comparator(self): - sel = select([table1.c.col1]).as_scalar() + sel = select([table1.c.col1]).scalar_subquery() expr = sel == table1.c.col1 sel2 = visitors.ReplacingCloningVisitor().traverse(sel) @@ -2535,7 +2540,7 @@ class ResultMapTest(fixtures.TestBase): def test_column_subquery_plain(self): t = self._fixture() - s1 = select([t.c.x]).where(t.c.x > 5).as_scalar() + s1 = select([t.c.x]).where(t.c.x > 5).scalar_subquery() s2 = select([s1]) mapping = self._mapping(s2) assert t.c.x not in mapping diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 188ac3878c..afc7755983 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -564,7 +564,7 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): def test_scalar_subquery(self): t = text("select id from user").columns(id=Integer) - subq = t.as_scalar() + subq = t.scalar_subquery() assert subq.type._type_affinity is Integer()._type_affinity diff --git a/test/sql/test_types.py b/test/sql/test_types.py index a1b1f024b9..a5c9313f80 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -2608,7 +2608,8 @@ class ExpressionTest( assert_raises_message( exc.ArgumentError, - r"Object some_sqla_thing\(\) is not legal as a SQL literal value", + r"SQL expression element or literal value expected, got " + r"some_sqla_thing\(\).", lambda: column("a", String) == SomeSQLAThing(), ) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 514076daab..9309ca45a9 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -179,7 +179,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) - u = update(table1, table1.c.name == s) + u = update(table1, table1.c.name == s.scalar_subquery()) self.assert_compile( u, "UPDATE mytable SET myid=:myid, name=:name, " @@ -194,7 +194,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): # test one that is actually correlated... s = select([table2.c.othername], table2.c.otherid == table1.c.myid) - u = table1.update(table1.c.name == s) + u = table1.update(table1.c.name == s.scalar_subquery()) self.assert_compile( u, "UPDATE mytable SET myid=:myid, name=:name, "