]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement new ClauseElement role and coercion system
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 Apr 2019 03:26:36 +0000 (23:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 May 2019 21:46:10 +0000 (17:46 -0400)
A major refactoring of all the functions handle all detection of
Core argument types as well as perform coercions into a new class hierarchy
based on "roles", each of which identify a syntactical location within a
SQL statement.  In contrast to the ClauseElement hierarchy that identifies
"what" each object is syntactically, the SQLRole hierarchy identifies
the "where does it go" of each object syntactically.   From this we define
a consistent type checking and coercion system that establishes well
defined behviors.

This is a breakout of the patch that is reorganizing select()
constructs to no longer be in the FromClause hierarchy.

Also includes a rename of as_scalar() into scalar_subquery(); deprecates
automatic coercion to scalar_subquery().

Partially-fixes: #4617
Change-Id: I26f1e78898693c6b99ef7ea2f4e7dfd0e8e1a1bd

85 files changed:
doc/build/changelog/unreleased_14/4617_coercion.rst [new file with mode: 0644]
doc/build/changelog/unreleased_14/4617_scalar.rst [new file with mode: 0644]
doc/build/core/tutorial.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/ext.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/baked.py
lib/sqlalchemy/ext/declarative/clsregistry.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/roles.py [new file with mode: 0644]
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/suite/test_cte.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_query.py
test/dialect/mysql/test_query.py
test/ext/declarative/test_basic.py
test/ext/declarative/test_mixin.py
test/ext/test_baked.py
test/orm/inheritance/test_assorted_poly.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_polymorphic_rel.py
test/orm/inheritance/test_single.py
test/orm/test_collection.py
test/orm/test_deprecations.py
test/orm/test_eager_relations.py
test/orm/test_froms.py
test/orm/test_lazy_relations.py
test/orm/test_mapper.py
test/orm/test_query.py
test/orm/test_rel_fn.py
test/orm/test_relationships.py
test/orm/test_update_delete.py
test/profiles.txt
test/sql/test_case_statement.py
test/sql/test_compare.py
test/sql/test_compiler.py
test/sql/test_cte.py
test/sql/test_defaults.py
test/sql/test_delete.py
test/sql/test_deprecations.py
test/sql/test_generative.py
test/sql/test_inspect.py
test/sql/test_join_rewriting.py
test/sql/test_metadata.py
test/sql/test_operators.py
test/sql/test_resultset.py
test/sql/test_returning.py
test/sql/test_roles.py [new file with mode: 0644]
test/sql/test_selectable.py
test/sql/test_text.py
test/sql/test_types.py
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_14/4617_coercion.rst b/doc/build/changelog/unreleased_14/4617_coercion.rst
new file mode 100644 (file)
index 0000000..93be6d5
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: sql, change
+    :tickets: 4617
+
+    The "clause coercion" system, which is SQLAlchemy Core's system of receiving
+    arguments and resolving them into :class:`.ClauseElement` structures in order
+    to build up SQL expression objects, has been rewritten from a series of
+    ad-hoc functions to a fully consistent class-based system.   This change
+    is internal and should have no impact on end users other than more specific
+    error messages when the wrong kind of argument is passed to an expression
+    object, however the change is part of a larger set of changes involving
+    the role and behavior of :func:`.select` objects.
+
diff --git a/doc/build/changelog/unreleased_14/4617_scalar.rst b/doc/build/changelog/unreleased_14/4617_scalar.rst
new file mode 100644 (file)
index 0000000..3f22414
--- /dev/null
@@ -0,0 +1,19 @@
+.. change::
+    :tags: change, sql
+    :tickets: 4617
+
+    The :meth:`.SelectBase.as_scalar` and :meth:`.Query.as_scalar` methods have
+    been renamed to :meth:`.SelectBase.scalar_subquery` and :meth:`.Query.scalar_subquery`,
+    respectively.  The old names continue to exist within 1.4 series with a deprecation
+    warning.  In addition, the implicit coercion of :class:`.SelectBase`, :class:`.Alias`,
+    and other SELECT oriented objects into scalar subqueries when evaluated in a column
+    context is also deprecated, and emits a warning that the :meth:`.SelectBase.scalar_subquery`
+    method should be called explicitly.  This warning will in a later major release
+    become an error, however the message will always be clear when :meth:`.SelectBase.scalar_subquery`
+    needs to be invoked.   The latter part of the change is for clarity and to reduce the
+    implicit decisionmaking by the query coercion system.
+
+    This change is part of the larger change to convert :func:`.select` objects to no
+    longer be directly part of the "from clause" class hierarchy, which also includes
+    an overhaul of the clause coercion system.
+
index cdacedadfe772503a821a327d60a856e7af31ccb..4409996bdfaa49faadded78eb1e40f0f219267c7 100644 (file)
@@ -1590,24 +1590,24 @@ is often a :term:`correlated subquery`, which relies upon the enclosing
 SELECT statement in order to acquire at least one of its FROM clauses.
 
 The :func:`.select` construct can be modified to act as a
-column expression by calling either the :meth:`~.SelectBase.as_scalar`
+column expression by calling either the :meth:`~.SelectBase.scalar_subquery`
 or :meth:`~.SelectBase.label` method:
 
 .. sourcecode:: pycon+sql
 
-    >>> stmt = select([func.count(addresses.c.id)]).\
+    >>> subq = select([func.count(addresses.c.id)]).\
     ...             where(users.c.id == addresses.c.user_id).\
-    ...             as_scalar()
+    ...             scalar_subquery()
 
 The above construct is now a :class:`~.expression.ScalarSelect` object,
-and is no longer part of the :class:`~.expression.FromClause` hierarchy;
-it instead is within the :class:`~.expression.ColumnElement` family of
-expression constructs.  We can place this construct the same as any
+which is an adapter around the original :class:`.~expression.Select`
+object; it participates within the :class:`~.expression.ColumnElement`
+family of expression constructs.  We can place this construct the same as any
 other column within another :func:`.select`:
 
 .. sourcecode:: pycon+sql
 
-    >>> conn.execute(select([users.c.name, stmt])).fetchall()
+    >>> conn.execute(select([users.c.name, subq])).fetchall()
     {opensql}SELECT users.name, (SELECT count(addresses.id) AS count_1
     FROM addresses
     WHERE users.id = addresses.user_id) AS anon_1
@@ -1620,10 +1620,10 @@ it using :meth:`.SelectBase.label` instead:
 
 .. sourcecode:: pycon+sql
 
-    >>> stmt = select([func.count(addresses.c.id)]).\
+    >>> subq = select([func.count(addresses.c.id)]).\
     ...             where(users.c.id == addresses.c.user_id).\
     ...             label("address_count")
-    >>> conn.execute(select([users.c.name, stmt])).fetchall()
+    >>> conn.execute(select([users.c.name, subq])).fetchall()
     {opensql}SELECT users.name, (SELECT count(addresses.id) AS count_1
     FROM addresses
     WHERE users.id = addresses.user_id) AS address_count
@@ -1633,7 +1633,7 @@ it using :meth:`.SelectBase.label` instead:
 
 .. seealso::
 
-    :meth:`.Select.as_scalar`
+    :meth:`.Select.scalar_subquery`
 
     :meth:`.Select.label`
 
@@ -1642,7 +1642,7 @@ it using :meth:`.SelectBase.label` instead:
 Correlated Subqueries
 ---------------------
 
-Notice in the examples on :ref:`scalar_selects`, the FROM clause of each embedded
+In the examples on :ref:`scalar_selects`, the FROM clause of each embedded
 select did not contain the ``users`` table in its FROM clause. This is because
 SQLAlchemy automatically :term:`correlates` embedded FROM objects to that
 of an enclosing query, if present, and if the inner SELECT statement would
@@ -1653,7 +1653,8 @@ still have at least one FROM clause of its own.  For example:
     >>> stmt = select([addresses.c.user_id]).\
     ...             where(addresses.c.user_id == users.c.id).\
     ...             where(addresses.c.email_address == 'jack@yahoo.com')
-    >>> enclosing_stmt = select([users.c.name]).where(users.c.id == stmt)
+    >>> enclosing_stmt = select([users.c.name]).\
+    ...             where(users.c.id == stmt.scalar_subquery())
     >>> conn.execute(enclosing_stmt).fetchall()
     {opensql}SELECT users.name
     FROM users
@@ -1679,7 +1680,7 @@ may be correlated:
     >>> enclosing_stmt = select(
     ...         [users.c.name, addresses.c.email_address]).\
     ...     select_from(users.join(addresses)).\
-    ...     where(users.c.id == stmt)
+    ...     where(users.c.id == stmt.scalar_subquery())
     >>> conn.execute(enclosing_stmt).fetchall()
     {opensql}SELECT users.name, addresses.email_address
      FROM users JOIN addresses ON users.id = addresses.user_id
@@ -1698,7 +1699,7 @@ as the argument:
     ...             where(users.c.name == 'wendy').\
     ...             correlate(None)
     >>> enclosing_stmt = select([users.c.name]).\
-    ...     where(users.c.id == stmt)
+    ...     where(users.c.id == stmt.scalar_subquery())
     >>> conn.execute(enclosing_stmt).fetchall()
     {opensql}SELECT users.name
      FROM users
@@ -1721,7 +1722,7 @@ by telling it to correlate all FROM clauses except for ``users``:
     >>> enclosing_stmt = select(
     ...         [users.c.name, addresses.c.email_address]).\
     ...     select_from(users.join(addresses)).\
-    ...     where(users.c.id == stmt)
+    ...     where(users.c.id == stmt.scalar_subquery())
     >>> conn.execute(enclosing_stmt).fetchall()
     {opensql}SELECT users.name, addresses.email_address
      FROM users JOIN addresses ON users.id = addresses.user_id
index d2c84f446aa39fd7a30dedcbfc585eca124f2280..00a110aa22b2c9823c8f9c422cb2e00202ed666c 100644 (file)
@@ -672,6 +672,7 @@ from ... import util
 from ...engine import default
 from ...engine import reflection
 from ...sql import compiler
+from ...sql import elements
 from ...sql import expression
 from ...sql import quoted_name
 from ...sql import util as sql_util
@@ -1671,9 +1672,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             # translate for schema-qualified table aliases
             t = self._schema_aliased_table(column.table)
             if t is not None:
-                converted = expression._corresponding_column_or_error(
-                    t, column
-                )
+                converted = elements._corresponding_column_or_error(t, column)
                 if add_to_result_map is not None:
                     add_to_result_map(
                         column.name,
index 44f90c47cda60c87d7fa5fa10e178960ecf248c8..9cae3c689fe32c5a3b36c7c0cfb401f5245631d4 100644 (file)
@@ -788,8 +788,10 @@ from ... import types as sqltypes
 from ... import util
 from ...engine import default
 from ...engine import reflection
+from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
+from ...sql import roles
 from ...types import BINARY
 from ...types import BLOB
 from ...types import BOOLEAN
@@ -1218,7 +1220,7 @@ class MySQLCompiler(compiler.SQLCompiler):
     def visit_on_duplicate_key_update(self, on_duplicate, **kw):
         if on_duplicate._parameter_ordering:
             parameter_ordering = [
-                elements._column_as_key(key)
+                coercions.expect(roles.DMLColumnRole, key)
                 for key in on_duplicate._parameter_ordering
             ]
             ordered_keys = set(parameter_ordering)
@@ -1238,7 +1240,7 @@ class MySQLCompiler(compiler.SQLCompiler):
             val = on_duplicate.update.get(column.key)
             if val is None:
                 continue
-            elif elements._is_literal(val):
+            elif coercions._is_literal(val):
                 val = elements.BindParameter(None, val, type_=column.type)
                 value_text = self.process(val.self_group(), use_schema=False)
             elif isinstance(val, elements.BindParameter) and val.type._isnull:
index ceb62464429bbc5625cd472e356cdd425facadb2..f18bec932b4f772a343ae66f47eee94ee4a182bb 100644 (file)
@@ -932,9 +932,11 @@ from ... import sql
 from ... import util
 from ...engine import default
 from ...engine import reflection
+from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
 from ...sql import expression
+from ...sql import roles
 from ...sql import sqltypes
 from ...sql import util as sql_util
 from ...types import BIGINT
@@ -1774,7 +1776,7 @@ class PGCompiler(compiler.SQLCompiler):
             col_key = c.key
             if col_key in set_parameters:
                 value = set_parameters.pop(col_key)
-                if elements._is_literal(value):
+                if coercions._is_literal(value):
                     value = elements.BindParameter(None, value, type_=c.type)
 
                 else:
@@ -1806,7 +1808,8 @@ class PGCompiler(compiler.SQLCompiler):
                     else self.process(k, use_schema=False)
                 )
                 value_text = self.process(
-                    elements._literal_as_binds(v), use_schema=False
+                    coercions.expect(roles.ExpressionElementRole, v),
+                    use_schema=False,
                 )
                 action_set_ops.append("%s = %s" % (key_text, value_text))
 
index 42602823971f81018d32c8ffeff7451a2cf22698..f9cbc945acdb58b8f151072e305734a10a3d1dee 100644 (file)
@@ -6,9 +6,12 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 from .array import ARRAY
+from ... import util
+from ...sql import coercions
 from ...sql import elements
 from ...sql import expression
 from ...sql import functions
+from ...sql import roles
 from ...sql.schema import ColumnCollectionConstraint
 
 
@@ -50,16 +53,18 @@ class aggregate_order_by(expression.ColumnElement):
     __visit_name__ = "aggregate_order_by"
 
     def __init__(self, target, *order_by):
-        self.target = elements._literal_as_binds(target)
+        self.target = coercions.expect(roles.ExpressionElementRole, target)
 
         _lob = len(order_by)
         if _lob == 0:
             raise TypeError("at least one ORDER BY element is required")
         elif _lob == 1:
-            self.order_by = elements._literal_as_binds(order_by[0])
+            self.order_by = coercions.expect(
+                roles.ExpressionElementRole, order_by[0]
+            )
         else:
             self.order_by = elements.ClauseList(
-                *order_by, _literal_as_text=elements._literal_as_binds
+                *order_by, _literal_as_text_role=roles.ExpressionElementRole
             )
 
     def self_group(self, against=None):
@@ -166,7 +171,10 @@ class ExcludeConstraint(ColumnCollectionConstraint):
         expressions, operators = zip(*elements)
 
         for (expr, column, strname, add_element), operator in zip(
-            self._extract_col_expression_collection(expressions), operators
+            coercions.expect_col_expression_collection(
+                roles.DDLConstraintColumnRole, expressions
+            ),
+            operators,
         ):
             if add_element is not None:
                 columns.append(add_element)
@@ -177,8 +185,6 @@ class ExcludeConstraint(ColumnCollectionConstraint):
                 # backwards compat
                 self.operators[name] = operator
 
-            expr = expression._literal_as_column(expr)
-
             render_exprs.append((expr, name, operator))
 
         self._render_exprs = render_exprs
@@ -193,9 +199,21 @@ class ExcludeConstraint(ColumnCollectionConstraint):
         self.using = kw.get("using", "gist")
         where = kw.get("where")
         if where is not None:
-            self.where = expression._literal_as_text(
-                where, allow_coercion_to_text=True
+            self.where = coercions.expect(roles.StatementOptionRole, where)
+
+    def _set_parent(self, table):
+        super(ExcludeConstraint, self)._set_parent(table)
+
+        self._render_exprs = [
+            (
+                expr if isinstance(expr, elements.ClauseElement) else colexpr,
+                name,
+                operator,
+            )
+            for (expr, name, operator), colexpr in util.zip_longest(
+                self._render_exprs, self.columns
             )
+        ]
 
     def copy(self, **kw):
         elements = [(col, self.operators[col]) for col in self.columns.keys()]
index fcb1d4155416fe6e0caa3d07f9029949a167af4b..d0b16a74570eb07e9b46047d10505f695be0fb65 100644 (file)
@@ -626,7 +626,7 @@ class ResultMetaData(object):
             if raiseerr:
                 raise exc.NoSuchColumnError(
                     "Could not locate column in row for column '%s'"
-                    % expression._string_or_unprintable(key)
+                    % util.string_or_unprintable(key)
                 )
             else:
                 return None
index 09a9d73b7308f9eb9ac557c2f3ef229c2ff82418..4bb55ca493c3487b2a73a904be63fd3c127326ad 100644 (file)
@@ -269,7 +269,8 @@ class BakedQuery(object):
                 User.id == Address.user_id).correlate(Address)
 
             main_bq = self.bakery(
-                lambda s: s.query(Address.id, sub_bq.to_query(q).as_scalar())
+                lambda s: s.query(
+                Address.id, sub_bq.to_query(q).scalar_subquery())
             )
 
         :param query_or_session: a :class:`.Query` object or a class
index 25752f2e82d25d3d9370607b67b40b7dec763e85..1d05ddc996222ed5ae16ff1e221f09364d983335 100644 (file)
@@ -217,7 +217,7 @@ class _GetColumns(object):
         mp = class_mapper(self.cls, configure=False)
         if mp:
             if key not in mp.all_orm_descriptors:
-                raise exc.InvalidRequestError(
+                raise AttributeError(
                     "Class %r does not have a mapped column named %r"
                     % (self.cls, key)
                 )
index 1b8c8c7f33d4cf0e4d616e2c8453001c83b8bf49..f37928cc1ebd60011ed6683b2e3f98bb713582d7 100644 (file)
@@ -109,10 +109,6 @@ class QueryableAttribute(
             instance_state(instance), instance_dict(instance), passive
         )
 
-    def __selectable__(self):
-        # TODO: conditionally attach this method based on clause_element ?
-        return self
-
     @util.memoized_property
     def info(self):
         """Return the 'info' dictionary for the underlying SQL element.
index b64b35e72eee3c338568954f4e2985a2412e83c9..f809d5891b763e8e176a8b6826e7c3c25f766b19 100644 (file)
@@ -15,7 +15,6 @@ from . import exc
 from .. import exc as sa_exc
 from .. import inspection
 from .. import util
-from ..sql import expression
 
 
 PASSIVE_NO_RESULT = util.symbol(
@@ -356,13 +355,6 @@ def _is_mapped_class(entity):
     )
 
 
-def _attr_as_key(attr):
-    if hasattr(attr, "key"):
-        return attr.key
-    else:
-        return expression._column_as_key(attr)
-
-
 def _orm_columns(entity):
     insp = inspection.inspect(entity, False)
     if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
index b9297e15c9b041c8125a6295696f1cf7a4f8e1ef..6bd009fb8933311bd25e98faeb28e017d544a2b5 100644 (file)
@@ -110,8 +110,9 @@ from sqlalchemy.util.compat import inspect_getfullargspec
 from . import base
 from .. import exc as sa_exc
 from .. import util
+from ..sql import coercions
 from ..sql import expression
-
+from ..sql import roles
 
 __all__ = [
     "collection",
@@ -243,7 +244,7 @@ def column_mapped_collection(mapping_spec):
 
     """
     cols = [
-        expression._only_column_elements(q, "mapping_spec")
+        coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec")
         for q in util.to_list(mapping_spec)
     ]
     keyfunc = _PlainColumnGetter(cols)
index d2b08a9087b58a2619221f949a8e8dceb6eabb54..5098a55ce3793065a28bfa1ebd2a31948dcac622 100644 (file)
@@ -35,6 +35,7 @@ from .base import MANYTOONE
 from .base import NOT_EXTENSION
 from .base import ONETOMANY
 from .. import inspect
+from .. import inspection
 from .. import util
 from ..sql import operators
 
@@ -270,6 +271,7 @@ class MapperProperty(_MappedAttribute, InspectionAttr, util.MemoizedSlots):
         )
 
 
+@inspection._self_inspects
 class PropComparator(operators.ColumnOperators):
     r"""Defines SQL operators for :class:`.MapperProperty` objects.
 
index 6eadffb16030d36efd2f09e95a1e23b8a6cf6785..ccf05a7833ba81ddaed2b76d2bde486419c23fde 100644 (file)
@@ -45,8 +45,10 @@ from .. import log
 from .. import schema
 from .. import sql
 from .. import util
+from ..sql import coercions
 from ..sql import expression
 from ..sql import operators
+from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
 
@@ -644,8 +646,14 @@ class Mapper(InspectionAttr):
         self.batch = batch
         self.eager_defaults = eager_defaults
         self.column_prefix = column_prefix
-        self.polymorphic_on = expression._clause_element_as_expr(
-            polymorphic_on
+        self.polymorphic_on = (
+            coercions.expect(
+                roles.ColumnArgumentOrKeyRole,
+                polymorphic_on,
+                argname="polymorphic_on",
+            )
+            if polymorphic_on is not None
+            else None
         )
         self._dependency_processors = []
         self.validators = util.immutabledict()
@@ -1548,14 +1556,6 @@ class Mapper(InspectionAttr):
                         "can be passed for polymorphic_on"
                     )
                 prop = self.polymorphic_on
-            elif not expression._is_column(self.polymorphic_on):
-                # polymorphic_on is not a Column and not a ColumnProperty;
-                # not supported right now.
-                raise sa_exc.ArgumentError(
-                    "Only direct column-mapped "
-                    "property or SQL expression "
-                    "can be passed for polymorphic_on"
-                )
             else:
                 # polymorphic_on is a Column or SQL expression and
                 # doesn't appear to be mapped. this means it can be 1.
@@ -1851,11 +1851,7 @@ class Mapper(InspectionAttr):
         # generate a properties.ColumnProperty
         columns = util.to_list(prop)
         column = columns[0]
-        if not expression._is_column(column):
-            raise sa_exc.ArgumentError(
-                "%s=%r is not an instance of MapperProperty or Column"
-                % (key, prop)
-            )
+        assert isinstance(column, expression.ColumnElement)
 
         prop = self._props.get(key, None)
 
@@ -2260,6 +2256,9 @@ class Mapper(InspectionAttr):
             for table, columns in self._cols_by_table.items()
         )
 
+    def __clause_element__(self):
+        return self.selectable
+
     @property
     def selectable(self):
         """The :func:`.select` construct this :class:`.Mapper` selects from
index e686b61c36ded064d71aa23767bdc1048b70532b..5106bff94e5300a2f26d28fb6d635ea70fe2f1a2 100644 (file)
@@ -28,7 +28,9 @@ from .base import state_str
 from .. import exc as sa_exc
 from .. import sql
 from .. import util
+from ..sql import coercions
 from ..sql import expression
+from ..sql import roles
 from ..sql.base import _from_objects
 
 
@@ -1914,7 +1916,7 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
         values = self._resolved_values_keys_as_propnames
         for key, value in values:
             self.value_evaluators[key] = evaluator_compiler.process(
-                expression._literal_as_binds(value)
+                coercions.expect(roles.ExpressionElementRole, value)
             )
 
     def _do_post_synchronize(self):
index 2b323429fccc43e0b4da5698261348b9fbe79639..e2c10e50aa964e985a41a1cfcea3054522688e26 100644 (file)
@@ -19,7 +19,8 @@ from .interfaces import StrategizedProperty
 from .util import _orm_full_deannotate
 from .. import log
 from .. import util
-from ..sql import expression
+from ..sql import coercions
+from ..sql import roles
 
 
 __all__ = ["ColumnProperty"]
@@ -131,9 +132,14 @@ class ColumnProperty(StrategizedProperty):
 
         """
         super(ColumnProperty, self).__init__()
-        self._orig_columns = [expression._labeled(c) for c in columns]
+        self._orig_columns = [
+            coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
+        ]
         self.columns = [
-            expression._labeled(_orm_full_deannotate(c)) for c in columns
+            coercions.expect(
+                roles.LabeledColumnExprRole, _orm_full_deannotate(c)
+            )
+            for c in columns
         ]
         self.group = kwargs.pop("group", None)
         self.deferred = kwargs.pop("deferred", False)
index 6d777ceae6a47efa384af7f8f20ebefa5a9617f0..8ef1663a12f1418311dd9680a03c0898bc53777b 100644 (file)
@@ -47,11 +47,12 @@ from .. import inspection
 from .. import log
 from .. import sql
 from .. import util
+from ..sql import coercions
 from ..sql import expression
+from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.base import ColumnCollection
-from ..sql.expression import _interpret_as_from
 from ..sql.selectable import ForUpdateArg
 
 
@@ -252,7 +253,7 @@ class Query(object):
                         "expected when the base alias is being set."
                     )
                 fa.append(info.selectable)
-            elif not info.is_selectable:
+            elif not info.is_clause_element or not info._is_from_clause:
                 raise sa_exc.ArgumentError(
                     "argument is not a mapped class, mapper, "
                     "aliased(), or FromClause instance."
@@ -309,9 +310,7 @@ class Query(object):
 
     def _adapt_col_list(self, cols):
         return [
-            self._adapt_clause(
-                expression._literal_as_label_reference(o), True, True
-            )
+            self._adapt_clause(coercions.expect(roles.ByOfRole, o), True, True)
             for o in cols
         ]
 
@@ -637,6 +636,12 @@ class Query(object):
 
         return self.enable_eagerloads(False).statement.label(name)
 
+    @util.deprecated(
+        "1.4",
+        "The :meth:`.Query.as_scalar` method is deprecated and will be "
+        "removed in a future release.  Please refer to "
+        ":meth:`.Query.scalar_subquery`.",
+    )
     def as_scalar(self):
         """Return the full SELECT statement represented by this
         :class:`.Query`, converted to a scalar subquery.
@@ -644,19 +649,20 @@ class Query(object):
         Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.as_scalar`.
 
         """
+        return self.scalar_subquery()
 
-        return self.enable_eagerloads(False).statement.as_scalar()
-
-    @property
-    def selectable(self):
-        """Return the :class:`.Select` object emitted by this :class:`.Query`.
-
-        Used for :func:`.inspect` compatibility, this is equivalent to::
+    def scalar_subquery(self):
+        """Return the full SELECT statement represented by this
+        :class:`.Query`, converted to a scalar subquery.
 
-            query.enable_eagerloads(False).with_labels().statement
+        Analogous to
+        :meth:`sqlalchemy.sql.expression.SelectBase.scalar_subquery`.
 
+        .. versionchanged:: 1.4 the :meth:`.Query.scalar_subquery` method
+           replaces the :meth:`.Query.as_scalar` method.
         """
-        return self.__clause_element__()
+
+        return self.enable_eagerloads(False).statement.scalar_subquery()
 
     def __clause_element__(self):
         return self.enable_eagerloads(False).with_labels().statement
@@ -1094,7 +1100,9 @@ class Query(object):
                 self._correlate = self._correlate.union([None])
             else:
                 self._correlate = self._correlate.union(
-                    sql_util.surface_selectables(_interpret_as_from(s))
+                    sql_util.surface_selectables(
+                        coercions.expect(roles.FromClauseRole, s)
+                    )
                 )
 
     @_generative()
@@ -1757,7 +1765,7 @@ class Query(object):
 
         """
         for criterion in list(criterion):
-            criterion = expression._expression_literal_as_text(criterion)
+            criterion = coercions.expect(roles.WhereHavingRole, criterion)
 
             criterion = self._adapt_clause(criterion, True, True)
 
@@ -1870,7 +1878,7 @@ class Query(object):
 
         """
 
-        criterion = expression._expression_literal_as_text(criterion)
+        criterion = coercions.expect(roles.WhereHavingRole, criterion)
 
         if criterion is not None and not isinstance(
             criterion, sql.ClauseElement
@@ -3184,7 +3192,7 @@ class Query(object):
             ORM tutorial
 
         """
-        statement = expression._expression_literal_as_text(statement)
+        statement = coercions.expect(roles.SelectStatementRole, statement)
 
         if not isinstance(
             statement, (expression.TextClause, expression.SelectBase)
@@ -4651,7 +4659,7 @@ class QueryContext(object):
         if query._statement is not None:
             if (
                 isinstance(query._statement, expression.SelectBase)
-                and not query._statement._textual
+                and not query._statement._is_textual
                 and not query._statement.use_labels
             ):
                 self.statement = query._statement.apply_labels()
index b40fad332bfa418944f3c87a97f93c7e776df88b..8b03955e22a5afd8485fec8dbe62598e819db65a 100644 (file)
@@ -36,8 +36,10 @@ from .. import schema
 from .. import sql
 from .. import util
 from ..inspection import inspect
+from ..sql import coercions
 from ..sql import expression
 from ..sql import operators
+from ..sql import roles
 from ..sql import visitors
 from ..sql.util import _deep_deannotate
 from ..sql.util import _shallow_annotate
@@ -63,7 +65,7 @@ def remote(expr):
 
     """
     return _annotate_columns(
-        expression._clause_element_as_expr(expr), {"remote": True}
+        coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
     )
 
 
@@ -83,7 +85,7 @@ def foreign(expr):
     """
 
     return _annotate_columns(
-        expression._clause_element_as_expr(expr), {"foreign": True}
+        coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
     )
 
 
@@ -1897,7 +1899,9 @@ class RelationshipProperty(StrategizedProperty):
                     self,
                     attr,
                     _orm_deannotate(
-                        expression._only_column_elements(val, attr)
+                        coercions.expect(
+                            roles.ColumnArgumentRole, val, argname=attr
+                        )
                     ),
                 )
 
@@ -1905,17 +1909,23 @@ class RelationshipProperty(StrategizedProperty):
         # remote_side are all columns, not strings.
         if self.order_by is not False and self.order_by is not None:
             self.order_by = [
-                expression._only_column_elements(x, "order_by")
+                coercions.expect(
+                    roles.ColumnArgumentRole, x, argname="order_by"
+                )
                 for x in util.to_list(self.order_by)
             ]
 
         self._user_defined_foreign_keys = util.column_set(
-            expression._only_column_elements(x, "foreign_keys")
+            coercions.expect(
+                roles.ColumnArgumentRole, x, argname="foreign_keys"
+            )
             for x in util.to_column_set(self._user_defined_foreign_keys)
         )
 
         self.remote_side = util.column_set(
-            expression._only_column_elements(x, "remote_side")
+            coercions.expect(
+                roles.ColumnArgumentRole, x, argname="remote_side"
+            )
             for x in util.to_column_set(self.remote_side)
         )
 
@@ -2653,7 +2663,9 @@ class JoinCondition(object):
         else:
 
             def repl(element):
-                if element in remote_side:
+                # use set() to avoid generating ``__eq__()`` expressions
+                # against each element
+                if element in set(remote_side):
                     return element._annotate({"remote": True})
 
             self.primaryjoin = visitors.replacement_traverse(
index 517fc2b36b801dafdcab774cc763809ae7fc3fa7..cc3361e904de4f0d4b3908ec104895fcfe7918cd 100644 (file)
@@ -32,7 +32,8 @@ from .. import exc as sa_exc
 from .. import sql
 from .. import util
 from ..inspection import inspect
-from ..sql import expression
+from ..sql import coercions
+from ..sql import roles
 from ..sql import util as sql_util
 
 
@@ -1257,9 +1258,7 @@ class Session(_SessionClassMethods):
             in order to execute the statement.
 
         """
-        clause = expression._literal_as_text(
-            clause, allow_coercion_to_text=True
-        )
+        clause = coercions.expect(roles.CoerceTextStatementRole, clause)
 
         if bind is None:
             bind = self.get_bind(mapper, clause=clause, **kw)
index 912b6b5503d27afc250eef8e329bfab5f68c5570..ebcb101adbb2f034c566ad1d03ccdf24e818421c 100644 (file)
@@ -24,7 +24,8 @@ from .util import _orm_full_deannotate
 from .. import exc as sa_exc
 from .. import inspect
 from .. import util
-from ..sql import expression as sql_expr
+from ..sql import coercions
+from ..sql import roles
 from ..sql.base import _generative
 from ..sql.base import Generative
 
@@ -1543,7 +1544,9 @@ def with_expression(loadopt, key, expression):
 
     """
 
-    expression = sql_expr._labeled(_orm_full_deannotate(expression))
+    expression = coercions.expect(
+        roles.LabeledColumnExprRole, _orm_full_deannotate(expression)
+    )
 
     return loadopt.set_column_strategy(
         (key,), {"query_expression": True}, opts={"expression": expression}
index 36da781a41a1400b4d6dbf2246a8ac6399b12626..e574181068a3f57781b17cd8777343148c9072b8 100644 (file)
@@ -628,6 +628,9 @@ class AliasedInsp(InspectionAttr):
     is_aliased_class = True
     "always returns True"
 
+    def __clause_element__(self):
+        return self.selectable
+
     @property
     def class_(self):
         """Return the mapped class ultimately represented by this
index fb5639ef372ce4d7dbbf3e13a356d395b0f570fa..00cafd8ff80ee275cb87b07930a76b7149cfdf37 100644 (file)
@@ -94,6 +94,22 @@ def __go(lcls):
     from .elements import ClauseList  # noqa
     from .selectable import AnnotatedFromClause  # noqa
 
+    from . import base
+    from . import coercions
+    from . import elements
+    from . import selectable
+    from . import schema
+    from . import sqltypes
+    from . import type_api
+
+    base.coercions = elements.coercions = coercions
+    base.elements = elements
+    base.type_api = type_api
+    coercions.elements = elements
+    coercions.schema = schema
+    coercions.selectable = selectable
+    coercions.sqltypes = sqltypes
+
     _prepare_annotations(ColumnElement, AnnotatedColumnElement)
     _prepare_annotations(FromClause, AnnotatedFromClause)
     _prepare_annotations(ClauseList, Annotated)
index c5e5fd8a1b20159c67ef8cdf94ad9cce2445018e..9df0c932f96abda31df7d393b695a6009449aba7 100644 (file)
@@ -17,6 +17,9 @@ from .visitors import ClauseVisitor
 from .. import exc
 from .. import util
 
+coercions = None  # type: types.ModuleType
+elements = None  # type: types.ModuleType
+type_api = None  # type: types.ModuleType
 
 PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
 NO_ARG = util.symbol("NO_ARG")
@@ -589,8 +592,7 @@ class ColumnCollection(util.OrderedProperties):
 
     __hash__ = None
 
-    @util.dependencies("sqlalchemy.sql.elements")
-    def __eq__(self, elements, other):
+    def __eq__(self, other):
         l = []
         for c in getattr(other, "_all_columns", other):
             for local in self._all_columns:
@@ -636,8 +638,7 @@ class ColumnSet(util.ordered_column_set):
     def __add__(self, other):
         return list(self) + list(other)
 
-    @util.dependencies("sqlalchemy.sql.elements")
-    def __eq__(self, elements, other):
+    def __eq__(self, other):
         l = []
         for c in other:
             for local in self:
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
new file mode 100644 (file)
index 0000000..7c7222f
--- /dev/null
@@ -0,0 +1,580 @@
+# sql/coercions.py
+# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import numbers
+import re
+
+from . import operators
+from . import roles
+from . import visitors
+from .visitors import Visitable
+from .. import exc
+from .. import inspection
+from .. import util
+from ..util import collections_abc
+
+elements = None  # type: types.ModuleType
+schema = None  # type: types.ModuleType
+selectable = None  # type: types.ModuleType
+sqltypes = None  # type: types.ModuleType
+
+
+def _is_literal(element):
+    """Return whether or not the element is a "literal" in the context
+    of a SQL expression construct.
+
+    """
+    return not isinstance(
+        element, (Visitable, schema.SchemaEventTarget)
+    ) and not hasattr(element, "__clause_element__")
+
+
+def _document_text_coercion(paramname, meth_rst, param_rst):
+    return util.add_parameter_text(
+        paramname,
+        (
+            ".. warning:: "
+            "The %s argument to %s can be passed as a Python string argument, "
+            "which will be treated "
+            "as **trusted SQL text** and rendered as given.  **DO NOT PASS "
+            "UNTRUSTED INPUT TO THIS PARAMETER**."
+        )
+        % (param_rst, meth_rst),
+    )
+
+
+def expect(role, element, **kw):
+    # major case is that we are given a ClauseElement already, skip more
+    # elaborate logic up front if possible
+    impl = _impl_lookup[role]
+
+    if not isinstance(element, (elements.ClauseElement, schema.SchemaItem)):
+        resolved = impl._resolve_for_clause_element(element, **kw)
+    else:
+        resolved = element
+
+    if issubclass(resolved.__class__, impl._role_class):
+        if impl._post_coercion:
+            resolved = impl._post_coercion(resolved, **kw)
+        return resolved
+    else:
+        return impl._implicit_coercions(element, resolved, **kw)
+
+
+def expect_as_key(role, element, **kw):
+    kw["as_key"] = True
+    return expect(role, element, **kw)
+
+
+def expect_col_expression_collection(role, expressions):
+    for expr in expressions:
+        strname = None
+        column = None
+
+        resolved = expect(role, expr)
+        if isinstance(resolved, util.string_types):
+            strname = resolved = expr
+        else:
+            cols = []
+            visitors.traverse(resolved, {}, {"column": cols.append})
+            if cols:
+                column = cols[0]
+        add_element = column if column is not None else strname
+        yield resolved, column, strname, add_element
+
+
+class RoleImpl(object):
+    __slots__ = ("_role_class", "name", "_use_inspection")
+
+    def _literal_coercion(self, element, **kw):
+        raise NotImplementedError()
+
+    _post_coercion = None
+
+    def __init__(self, role_class):
+        self._role_class = role_class
+        self.name = role_class._role_name
+        self._use_inspection = issubclass(role_class, roles.UsesInspection)
+
+    def _resolve_for_clause_element(self, element, argname=None, **kw):
+        literal_coercion = self._literal_coercion
+        original_element = element
+        is_clause_element = False
+
+        while hasattr(element, "__clause_element__") and not isinstance(
+            element, (elements.ClauseElement, schema.SchemaItem)
+        ):
+            element = element.__clause_element__()
+            is_clause_element = True
+
+        if not is_clause_element:
+            if self._use_inspection:
+                insp = inspection.inspect(element, raiseerr=False)
+                if insp is not None:
+                    try:
+                        return insp.__clause_element__()
+                    except AttributeError:
+                        self._raise_for_expected(original_element, argname)
+
+            return self._literal_coercion(element, argname=argname, **kw)
+        else:
+            return element
+
+    def _implicit_coercions(self, element, resolved, argname=None, **kw):
+        self._raise_for_expected(element, argname)
+
+    def _raise_for_expected(self, element, argname=None):
+        if argname:
+            raise exc.ArgumentError(
+                "%s expected for argument %r; got %r."
+                % (self.name, argname, element)
+            )
+        else:
+            raise exc.ArgumentError(
+                "%s expected, got %r." % (self.name, element)
+            )
+
+
+class _StringOnly(object):
+    def _resolve_for_clause_element(self, element, argname=None, **kw):
+        return self._literal_coercion(element, **kw)
+
+
+class _ReturnsStringKey(object):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if isinstance(original_element, util.string_types):
+            return original_element
+        else:
+            self._raise_for_expected(original_element, argname)
+
+    def _literal_coercion(self, element, **kw):
+        return element
+
+
+class _ColumnCoercions(object):
+    def _warn_for_scalar_subquery_coercion(self):
+        util.warn_deprecated(
+            "coercing SELECT object to scalar subquery in a "
+            "column-expression context is deprecated in version 1.4; "
+            "please use the .scalar_subquery() method to produce a scalar "
+            "subquery.  This automatic coercion will be removed in a "
+            "future release."
+        )
+
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if resolved._is_select_statement:
+            self._warn_for_scalar_subquery_coercion()
+            return resolved.scalar_subquery()
+        elif (
+            resolved._is_from_clause
+            and isinstance(resolved, selectable.Alias)
+            and resolved.original._is_select_statement
+        ):
+            self._warn_for_scalar_subquery_coercion()
+            return resolved.original.scalar_subquery()
+        else:
+            self._raise_for_expected(original_element, argname)
+
+
+def _no_text_coercion(
+    element, argname=None, exc_cls=exc.ArgumentError, extra=None
+):
+    raise exc_cls(
+        "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
+        "explicitly declared as text(%(expr)r)"
+        % {
+            "expr": util.ellipses_string(element),
+            "argname": "for argument %s" % (argname,) if argname else "",
+            "extra": "%s " % extra if extra else "",
+        }
+    )
+
+
+class _NoTextCoercion(object):
+    def _literal_coercion(self, element, argname=None):
+        if isinstance(element, util.string_types) and issubclass(
+            elements.TextClause, self._role_class
+        ):
+            _no_text_coercion(element, argname)
+        else:
+            self._raise_for_expected(element, argname)
+
+
+class _CoerceLiterals(object):
+    _coerce_consts = False
+    _coerce_star = False
+    _coerce_numerics = False
+
+    def _text_coercion(self, element, argname=None):
+        return _no_text_coercion(element, argname)
+
+    def _literal_coercion(self, element, argname=None):
+        if isinstance(element, util.string_types):
+            if self._coerce_star and element == "*":
+                return elements.ColumnClause("*", is_literal=True)
+            else:
+                return self._text_coercion(element, argname)
+
+        if self._coerce_consts:
+            if element is None:
+                return elements.Null()
+            elif element is False:
+                return elements.False_()
+            elif element is True:
+                return elements.True_()
+
+        if self._coerce_numerics and isinstance(element, (numbers.Number)):
+            return elements.ColumnClause(str(element), is_literal=True)
+
+        self._raise_for_expected(element, argname)
+
+
+class ExpressionElementImpl(
+    _ColumnCoercions, RoleImpl, roles.ExpressionElementRole
+):
+    def _literal_coercion(self, element, name=None, type_=None, argname=None):
+        if element is None:
+            return elements.Null()
+        else:
+            try:
+                return elements.BindParameter(
+                    name, element, type_, unique=True
+                )
+            except exc.ArgumentError:
+                self._raise_for_expected(element)
+
+
+class BinaryElementImpl(
+    ExpressionElementImpl, RoleImpl, roles.BinaryElementRole
+):
+    def _literal_coercion(
+        self, element, expr, operator, bindparam_type=None, argname=None
+    ):
+        try:
+            return expr._bind_param(operator, element, type_=bindparam_type)
+        except exc.ArgumentError:
+            self._raise_for_expected(element)
+
+    def _post_coercion(self, resolved, expr, **kw):
+        if (
+            isinstance(resolved, elements.BindParameter)
+            and resolved.type._isnull
+        ):
+            resolved = resolved._clone()
+            resolved.type = expr.type
+        return resolved
+
+
+class InElementImpl(RoleImpl, roles.InElementRole):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if resolved._is_from_clause:
+            if (
+                isinstance(resolved, selectable.Alias)
+                and resolved.original._is_select_statement
+            ):
+                return resolved.original
+            else:
+                return resolved.select()
+        else:
+            self._raise_for_expected(original_element, argname)
+
+    def _literal_coercion(self, element, expr, operator, **kw):
+        if isinstance(element, collections_abc.Iterable) and not isinstance(
+            element, util.string_types
+        ):
+            args = []
+            for o in element:
+                if not _is_literal(o):
+                    if not isinstance(o, operators.ColumnOperators):
+                        self._raise_for_expected(element, **kw)
+                elif o is None:
+                    o = elements.Null()
+                else:
+                    o = expr._bind_param(operator, o)
+                args.append(o)
+
+            return elements.ClauseList(*args)
+
+        else:
+            self._raise_for_expected(element, **kw)
+
+    def _post_coercion(self, element, expr, operator, **kw):
+        if element._is_select_statement:
+            return element.scalar_subquery()
+        elif isinstance(element, elements.ClauseList):
+            if len(element.clauses) == 0:
+                op, negate_op = (
+                    (operators.empty_in_op, operators.empty_notin_op)
+                    if operator is operators.in_op
+                    else (operators.empty_notin_op, operators.empty_in_op)
+                )
+                return element.self_group(against=op)._annotate(
+                    dict(in_ops=(op, negate_op))
+                )
+            else:
+                return element.self_group(against=operator)
+
+        elif isinstance(element, elements.BindParameter) and element.expanding:
+
+            if isinstance(expr, elements.Tuple):
+                element = element._with_expanding_in_types(
+                    [elem.type for elem in expr]
+                )
+            return element
+        else:
+            return element
+
+
+class WhereHavingImpl(
+    _CoerceLiterals, _ColumnCoercions, RoleImpl, roles.WhereHavingRole
+):
+
+    _coerce_consts = True
+
+    def _text_coercion(self, element, argname=None):
+        return _no_text_coercion(element, argname)
+
+
+class StatementOptionImpl(
+    _CoerceLiterals, RoleImpl, roles.StatementOptionRole
+):
+
+    _coerce_consts = True
+
+    def _text_coercion(self, element, argname=None):
+        return elements.TextClause(element)
+
+
+class ColumnArgumentImpl(_NoTextCoercion, RoleImpl, roles.ColumnArgumentRole):
+    pass
+
+
+class ColumnArgumentOrKeyImpl(
+    _ReturnsStringKey, RoleImpl, roles.ColumnArgumentOrKeyRole
+):
+    pass
+
+
+class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole):
+
+    _coerce_consts = True
+
+    def _text_coercion(self, element, argname=None):
+        return elements._textual_label_reference(element)
+
+
+class OrderByImpl(ByOfImpl, RoleImpl, roles.OrderByRole):
+    def _post_coercion(self, resolved):
+        if (
+            isinstance(resolved, self._role_class)
+            and resolved._order_by_label_element is not None
+        ):
+            return elements._label_reference(resolved)
+        else:
+            return resolved
+
+
+class DMLColumnImpl(_ReturnsStringKey, RoleImpl, roles.DMLColumnRole):
+    def _post_coercion(self, element, as_key=False):
+        if as_key:
+            return element.key
+        else:
+            return element
+
+
+class ConstExprImpl(RoleImpl, roles.ConstExprRole):
+    def _literal_coercion(self, element, argname=None):
+        if element is None:
+            return elements.Null()
+        elif element is False:
+            return elements.False_()
+        elif element is True:
+            return elements.True_()
+        else:
+            self._raise_for_expected(element, argname)
+
+
+class TruncatedLabelImpl(_StringOnly, RoleImpl, roles.TruncatedLabelRole):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if isinstance(original_element, util.string_types):
+            return resolved
+        else:
+            self._raise_for_expected(original_element, argname)
+
+    def _literal_coercion(self, element, argname=None):
+        """coerce the given value to :class:`._truncated_label`.
+
+        Existing :class:`._truncated_label` and
+        :class:`._anonymous_label` objects are passed
+        unchanged.
+        """
+
+        if isinstance(element, elements._truncated_label):
+            return element
+        else:
+            return elements._truncated_label(element)
+
+
+class DDLExpressionImpl(_CoerceLiterals, RoleImpl, roles.DDLExpressionRole):
+
+    _coerce_consts = True
+
+    def _text_coercion(self, element, argname=None):
+        return elements.TextClause(element)
+
+
+class DDLConstraintColumnImpl(
+    _ReturnsStringKey, RoleImpl, roles.DDLConstraintColumnRole
+):
+    pass
+
+
+class LimitOffsetImpl(RoleImpl, roles.LimitOffsetRole):
+    def _implicit_coercions(self, element, resolved, argname=None, **kw):
+        if resolved is None:
+            return None
+        else:
+            self._raise_for_expected(element, argname)
+
+    def _literal_coercion(self, element, name, type_, **kw):
+        if element is None:
+            return None
+        else:
+            value = util.asint(element)
+            return selectable._OffsetLimitParam(
+                name, value, type_=type_, unique=True
+            )
+
+
+class LabeledColumnExprImpl(
+    ExpressionElementImpl, roles.LabeledColumnExprRole
+):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if isinstance(resolved, roles.ExpressionElementRole):
+            return resolved.label(None)
+        else:
+            new = super(LabeledColumnExprImpl, self)._implicit_coercions(
+                original_element, resolved, argname=argname, **kw
+            )
+            if isinstance(new, roles.ExpressionElementRole):
+                return new.label(None)
+            else:
+                self._raise_for_expected(original_element, argname)
+
+
+class ColumnsClauseImpl(_CoerceLiterals, RoleImpl, roles.ColumnsClauseRole):
+
+    _coerce_consts = True
+    _coerce_numerics = True
+    _coerce_star = True
+
+    _guess_straight_column = re.compile(r"^\w\S*$", re.I)
+
+    def _text_coercion(self, element, argname=None):
+        element = str(element)
+
+        guess_is_literal = not self._guess_straight_column.match(element)
+        raise exc.ArgumentError(
+            "Textual column expression %(column)r %(argname)sshould be "
+            "explicitly declared with text(%(column)r), "
+            "or use %(literal_column)s(%(column)r) "
+            "for more specificity"
+            % {
+                "column": util.ellipses_string(element),
+                "argname": "for argument %s" % (argname,) if argname else "",
+                "literal_column": "literal_column"
+                if guess_is_literal
+                else "column",
+            }
+        )
+
+
+class ReturnsRowsImpl(RoleImpl, roles.ReturnsRowsRole):
+    pass
+
+
+class StatementImpl(_NoTextCoercion, RoleImpl, roles.StatementRole):
+    pass
+
+
+class CoerceTextStatementImpl(_CoerceLiterals, RoleImpl, roles.StatementRole):
+    def _text_coercion(self, element, argname=None):
+        return elements.TextClause(element)
+
+
+class SelectStatementImpl(
+    _NoTextCoercion, RoleImpl, roles.SelectStatementRole
+):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if resolved._is_text_clause:
+            return resolved.columns()
+        else:
+            self._raise_for_expected(original_element, argname)
+
+
+class HasCTEImpl(ReturnsRowsImpl, roles.HasCTERole):
+    pass
+
+
+class FromClauseImpl(_NoTextCoercion, RoleImpl, roles.FromClauseRole):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if resolved._is_text_clause:
+            return resolved
+        else:
+            self._raise_for_expected(original_element, argname)
+
+
+class DMLSelectImpl(_NoTextCoercion, RoleImpl, roles.DMLSelectRole):
+    def _implicit_coercions(
+        self, original_element, resolved, argname=None, **kw
+    ):
+        if resolved._is_from_clause:
+            if (
+                isinstance(resolved, selectable.Alias)
+                and resolved.original._is_select_statement
+            ):
+                return resolved.original
+            else:
+                return resolved.select()
+        else:
+            self._raise_for_expected(original_element, argname)
+
+
+class CompoundElementImpl(
+    _NoTextCoercion, RoleImpl, roles.CompoundElementRole
+):
+    def _implicit_coercions(self, original_element, resolved, argname=None):
+        if resolved._is_from_clause:
+            return resolved
+        else:
+            self._raise_for_expected(original_element, argname)
+
+
+_impl_lookup = {}
+
+
+for name in dir(roles):
+    cls = getattr(roles, name)
+    if name.endswith("Role"):
+        name = name.replace("Role", "Impl")
+        if name in globals():
+            impl = globals()[name](cls)
+            _impl_lookup[cls] = impl
index c7fe3dc50e164bb664ac2a42f68b539d68c4a860..8080d2cc662fb45af4b483adfdb3dba7824814ba 100644 (file)
@@ -27,14 +27,15 @@ import contextlib
 import itertools
 import re
 
+from . import coercions
 from . import crud
 from . import elements
 from . import functions
 from . import operators
+from . import roles
 from . import schema
 from . import selectable
 from . import sqltypes
-from . import visitors
 from .. import exc
 from .. import util
 
@@ -400,7 +401,9 @@ class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
         return type_._compiler_dispatch(self, **kw)
 
 
-class _CompileLabel(visitors.Visitable):
+# this was a Visitable, but to allow accurate detection of
+# column elements this is actually a column element
+class _CompileLabel(elements.ColumnElement):
 
     """lightweight label object which acts as an expression.Label."""
 
@@ -766,10 +769,10 @@ class SQLCompiler(Compiled):
             else:
                 col = with_cols[element.element]
         except KeyError:
-            elements._no_text_coercion(
+            coercions._no_text_coercion(
                 element.element,
-                exc.CompileError,
-                "Can't resolve label reference for ORDER BY / GROUP BY.",
+                extra="Can't resolve label reference for ORDER BY / GROUP BY.",
+                exc_cls=exc.CompileError,
             )
         else:
             kwargs["render_label_as_label"] = col
@@ -1635,7 +1638,6 @@ class SQLCompiler(Compiled):
         if is_new_cte:
             self.ctes_by_name[cte_name] = cte
 
-            # look for embedded DML ctes and propagate autocommit
             if (
                 "autocommit" in cte.element._execution_options
                 and "autocommit" not in self.execution_options
@@ -1656,10 +1658,10 @@ class SQLCompiler(Compiled):
                     self.ctes_recursive = True
                 text = self.preparer.format_alias(cte, cte_name)
                 if cte.recursive:
-                    if isinstance(cte.original, selectable.Select):
-                        col_source = cte.original
-                    elif isinstance(cte.original, selectable.CompoundSelect):
-                        col_source = cte.original.selects[0]
+                    if isinstance(cte.element, selectable.Select):
+                        col_source = cte.element
+                    elif isinstance(cte.element, selectable.CompoundSelect):
+                        col_source = cte.element.selects[0]
                     else:
                         assert False
                     recur_cols = [
@@ -1810,7 +1812,7 @@ class SQLCompiler(Compiled):
         ):
             result_expr = _CompileLabel(
                 col_expr,
-                elements._as_truncated(column.name),
+                coercions.expect(roles.TruncatedLabelRole, column.name),
                 alt_names=(column.key,),
             )
         elif (
@@ -1830,7 +1832,7 @@ class SQLCompiler(Compiled):
             # assert isinstance(column, elements.ColumnClause)
             result_expr = _CompileLabel(
                 col_expr,
-                elements._as_truncated(column.name),
+                coercions.expect(roles.TruncatedLabelRole, column.name),
                 alt_names=(column.key,),
             )
         else:
@@ -1880,7 +1882,7 @@ class SQLCompiler(Compiled):
             newelem = cloned[element] = element._clone()
 
             if (
-                newelem.is_selectable
+                newelem._is_from_clause
                 and newelem._is_join
                 and isinstance(newelem.right, selectable.FromGrouping)
             ):
@@ -1933,7 +1935,7 @@ class SQLCompiler(Compiled):
                 # marker in the stack.
                 kw["transform_clue"] = "select_container"
                 newelem._copy_internals(clone=visit, **kw)
-            elif newelem.is_selectable and newelem._is_select:
+            elif newelem._is_returns_rows and newelem._is_select_statement:
                 barrier_select = (
                     kw.get("transform_clue", None) == "select_container"
                 )
@@ -2349,6 +2351,7 @@ class SQLCompiler(Compiled):
             + join_type
             + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
             + " ON "
+            # TODO: likely need asfrom=True here?
             + join.onclause._compiler_dispatch(self, **kwargs)
         )
 
index 552f61b4a068d29f7c8cbb3a0745e731e6ba426d..881ea9fcda54dd62d867fd11bcf0756d21bb55fb 100644 (file)
@@ -9,14 +9,16 @@
 within INSERT and UPDATE statements.
 
 """
+import functools
 import operator
 
+from . import coercions
 from . import dml
 from . import elements
+from . import roles
 from .. import exc
 from .. import util
 
-
 REQUIRED = util.symbol(
     "REQUIRED",
     """
@@ -174,7 +176,7 @@ def _get_crud_params(compiler, stmt, **kw):
         if check:
             raise exc.CompileError(
                 "Unconsumed column names: %s"
-                % (", ".join("%s" % c for c in check))
+                % (", ".join("%s" % (c,) for c in check))
             )
 
     if stmt._has_multi_parameters:
@@ -207,8 +209,12 @@ def _key_getters_for_crud_column(compiler, stmt):
         # statement.
         _et = set(stmt._extra_froms)
 
+        c_key_role = functools.partial(
+            coercions.expect_as_key, roles.DMLColumnRole
+        )
+
         def _column_as_key(key):
-            str_key = elements._column_as_key(key)
+            str_key = c_key_role(key)
             if hasattr(key, "table") and key.table in _et:
                 return (key.table.name, str_key)
             else:
@@ -227,7 +233,9 @@ def _key_getters_for_crud_column(compiler, stmt):
                 return col.key
 
     else:
-        _column_as_key = elements._column_as_key
+        _column_as_key = functools.partial(
+            coercions.expect_as_key, roles.DMLColumnRole
+        )
         _getattr_col_key = _col_bind_name = operator.attrgetter("key")
 
     return _column_as_key, _getattr_col_key, _col_bind_name
@@ -386,7 +394,7 @@ def _append_param_parameter(
     kw,
 ):
     value = parameters.pop(col_key)
-    if elements._is_literal(value):
+    if coercions._is_literal(value):
         value = _create_bind_param(
             compiler,
             c,
@@ -633,9 +641,8 @@ def _get_multitable_params(
     values,
     kw,
 ):
-
     normalized_params = dict(
-        (elements._clause_element_as_expr(c), param)
+        (coercions.expect(roles.DMLColumnRole, c), param)
         for c, param in stmt_parameters.items()
     )
     affected_tables = set()
@@ -645,7 +652,7 @@ def _get_multitable_params(
                 affected_tables.add(t)
                 check_columns[_getattr_col_key(c)] = c
                 value = normalized_params[c]
-                if elements._is_literal(value):
+                if coercions._is_literal(value):
                     value = _create_bind_param(
                         compiler,
                         c,
@@ -697,7 +704,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
             if col in row or col.key in row:
                 key = col if col in row else col.key
 
-                if elements._is_literal(row[key]):
+                if coercions._is_literal(row[key]):
                     new_param = _create_bind_param(
                         compiler,
                         col,
@@ -730,7 +737,7 @@ def _get_stmt_parameters_params(
             # a non-Column expression on the left side;
             # add it to values() in an "as-is" state,
             # coercing right side to bound param
-            if elements._is_literal(v):
+            if coercions._is_literal(v):
                 v = compiler.process(
                     elements.BindParameter(None, v, type_=k.type), **kw
                 )
index d87a6a1b04014ee19267225e9418dbb8325f20aa..ff36a68e4a3bd36a46fb116abc24c48e4674f0ae 100644 (file)
@@ -10,6 +10,7 @@ to invoke them for a create/drop call.
 
 """
 
+from . import roles
 from .base import _bind_or_error
 from .base import _generative
 from .base import Executable
@@ -29,7 +30,7 @@ class _DDLCompiles(ClauseElement):
         return dialect.ddl_compiler(dialect, self, **kw)
 
 
-class DDLElement(Executable, _DDLCompiles):
+class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
     """Base class for DDL expression constructs.
 
     This class is the base for the general purpose :class:`.DDL` class,
index 9a12b84cd901eab6631d354ca2b7419dcc6ecb1e..918f7524e53cb21d2e365e665ce8645ee4238353 100644 (file)
@@ -8,32 +8,21 @@
 """Default implementation of SQL comparison operations.
 """
 
+
+from . import coercions
 from . import operators
+from . import roles
 from . import type_api
-from .elements import _clause_element_as_expr
-from .elements import _const_expr
-from .elements import _is_literal
-from .elements import _literal_as_text
 from .elements import and_
 from .elements import BinaryExpression
-from .elements import BindParameter
-from .elements import ClauseElement
 from .elements import ClauseList
 from .elements import collate
 from .elements import CollectionAggregate
-from .elements import ColumnElement
 from .elements import False_
 from .elements import Null
 from .elements import or_
-from .elements import TextClause
 from .elements import True_
-from .elements import Tuple
 from .elements import UnaryExpression
-from .elements import Visitable
-from .selectable import Alias
-from .selectable import ScalarSelect
-from .selectable import Selectable
-from .selectable import SelectBase
 from .. import exc
 from .. import util
 
@@ -62,7 +51,7 @@ def _boolean_compare(
         ):
             return BinaryExpression(
                 expr,
-                _literal_as_text(obj),
+                coercions.expect(roles.ConstExprRole, obj),
                 op,
                 type_=result_type,
                 negate=negate,
@@ -71,7 +60,7 @@ def _boolean_compare(
         elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
             return BinaryExpression(
                 expr,
-                _literal_as_text(obj),
+                coercions.expect(roles.ConstExprRole, obj),
                 op,
                 type_=result_type,
                 negate=negate,
@@ -82,7 +71,7 @@ def _boolean_compare(
             if op in (operators.eq, operators.is_):
                 return BinaryExpression(
                     expr,
-                    _const_expr(obj),
+                    coercions.expect(roles.ConstExprRole, obj),
                     operators.is_,
                     negate=operators.isnot,
                     type_=result_type,
@@ -90,7 +79,7 @@ def _boolean_compare(
             elif op in (operators.ne, operators.isnot):
                 return BinaryExpression(
                     expr,
-                    _const_expr(obj),
+                    coercions.expect(roles.ConstExprRole, obj),
                     operators.isnot,
                     negate=operators.is_,
                     type_=result_type,
@@ -102,7 +91,9 @@ def _boolean_compare(
                     "operators can be used with None/True/False"
                 )
     else:
-        obj = _check_literal(expr, op, obj)
+        obj = coercions.expect(
+            roles.BinaryElementRole, element=obj, operator=op, expr=expr
+        )
 
     if reverse:
         return BinaryExpression(
@@ -127,7 +118,9 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
 
 
 def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
-    obj = _check_literal(expr, op, obj)
+    obj = coercions.expect(
+        roles.BinaryElementRole, obj, expr=expr, operator=op
+    )
 
     if reverse:
         left, right = obj, expr
@@ -156,77 +149,22 @@ def _scalar(expr, op, fn, **kw):
 
 
 def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
-    seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
-
-    if isinstance(seq_or_selectable, ScalarSelect):
-        return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op)
-    elif isinstance(seq_or_selectable, SelectBase):
-
-        # TODO: if we ever want to support (x, y, z) IN (select x,
-        # y, z from table), we would need a multi-column version of
-        # as_scalar() to produce a multi- column selectable that
-        # does not export itself as a FROM clause
-
-        return _boolean_compare(
-            expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw
-        )
-    elif isinstance(seq_or_selectable, (Selectable, TextClause)):
-        return _boolean_compare(
-            expr, op, seq_or_selectable, negate=negate_op, **kw
-        )
-    elif isinstance(seq_or_selectable, ClauseElement):
-        if (
-            isinstance(seq_or_selectable, BindParameter)
-            and seq_or_selectable.expanding
-        ):
-
-            if isinstance(expr, Tuple):
-                seq_or_selectable = seq_or_selectable._with_expanding_in_types(
-                    [elem.type for elem in expr]
-                )
-
-            return _boolean_compare(
-                expr, op, seq_or_selectable, negate=negate_op
-            )
-        else:
-            raise exc.InvalidRequestError(
-                "in_() accepts"
-                " either a list of expressions, "
-                'a selectable, or an "expanding" bound parameter: %r'
-                % seq_or_selectable
-            )
-
-    # Handle non selectable arguments as sequences
-    args = []
-    for o in seq_or_selectable:
-        if not _is_literal(o):
-            if not isinstance(o, operators.ColumnOperators):
-                raise exc.InvalidRequestError(
-                    "in_() accepts"
-                    " either a list of expressions, "
-                    'a selectable, or an "expanding" bound parameter: %r' % o
-                )
-        elif o is None:
-            o = Null()
-        else:
-            o = expr._bind_param(op, o)
-        args.append(o)
-
-    if len(args) == 0:
-        op, negate_op = (
-            (operators.empty_in_op, operators.empty_notin_op)
-            if op is operators.in_op
-            else (operators.empty_notin_op, operators.empty_in_op)
-        )
+    seq_or_selectable = coercions.expect(
+        roles.InElementRole, seq_or_selectable, expr=expr, operator=op
+    )
+    if "in_ops" in seq_or_selectable._annotations:
+        op, negate_op = seq_or_selectable._annotations["in_ops"]
 
     return _boolean_compare(
-        expr, op, ClauseList(*args).self_group(against=op), negate=negate_op
+        expr, op, seq_or_selectable, negate=negate_op, **kw
     )
 
 
 def _getitem_impl(expr, op, other, **kw):
     if isinstance(expr.type, type_api.INDEXABLE):
-        other = _check_literal(expr, op, other)
+        other = coercions.expect(
+            roles.BinaryElementRole, other, expr=expr, operator=op
+        )
         return _binary_operate(expr, op, other, **kw)
     else:
         _unsupported_impl(expr, op, other, **kw)
@@ -257,7 +195,12 @@ def _match_impl(expr, op, other, **kw):
     return _boolean_compare(
         expr,
         operators.match_op,
-        _check_literal(expr, operators.match_op, other),
+        coercions.expect(
+            roles.BinaryElementRole,
+            other,
+            expr=expr,
+            operator=operators.match_op,
+        ),
         result_type=type_api.MATCHTYPE,
         negate=operators.notmatch_op
         if op is operators.match_op
@@ -278,8 +221,18 @@ def _between_impl(expr, op, cleft, cright, **kw):
     return BinaryExpression(
         expr,
         ClauseList(
-            _check_literal(expr, operators.and_, cleft),
-            _check_literal(expr, operators.and_, cright),
+            coercions.expect(
+                roles.BinaryElementRole,
+                cleft,
+                expr=expr,
+                operator=operators.and_,
+            ),
+            coercions.expect(
+                roles.BinaryElementRole,
+                cright,
+                expr=expr,
+                operator=operators.and_,
+            ),
             operator=operators.and_,
             group=False,
             group_contents=False,
@@ -349,22 +302,3 @@ operator_lookup = {
     "rshift": (_unsupported_impl,),
     "contains": (_unsupported_impl,),
 }
-
-
-def _check_literal(expr, operator, other, bindparam_type=None):
-    if isinstance(other, (ColumnElement, TextClause)):
-        if isinstance(other, BindParameter) and other.type._isnull:
-            other = other._clone()
-            other.type = expr.type
-        return other
-    elif hasattr(other, "__clause_element__"):
-        other = other.__clause_element__()
-    elif isinstance(other, type_api.TypeEngine.Comparator):
-        other = other.expr
-
-    if isinstance(other, (SelectBase, Alias)):
-        return other.as_scalar()
-    elif not isinstance(other, Visitable):
-        return expr._bind_param(operator, other, type_=bindparam_type)
-    else:
-        return other
index 3c40e7914387b98eef4c16484ec0a0f3bfb890f9..c7d83fc12b008ee306c65b1662e3f278d3e2d445 100644 (file)
@@ -9,18 +9,16 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
 
 """
 
+from . import coercions
+from . import roles
 from .base import _from_objects
 from .base import _generative
 from .base import DialectKWArgs
 from .base import Executable
 from .elements import _clone
-from .elements import _column_as_key
-from .elements import _literal_as_text
 from .elements import and_
 from .elements import ClauseElement
 from .elements import Null
-from .selectable import _interpret_as_from
-from .selectable import _interpret_as_select
 from .selectable import HasCTE
 from .selectable import HasPrefixes
 from .. import exc
@@ -28,7 +26,12 @@ from .. import util
 
 
 class UpdateBase(
-    HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement
+    roles.DMLRole,
+    HasCTE,
+    DialectKWArgs,
+    HasPrefixes,
+    Executable,
+    ClauseElement,
 ):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
 
@@ -210,7 +213,7 @@ class ValuesBase(UpdateBase):
     _post_values_clause = None
 
     def __init__(self, table, values, prefixes):
-        self.table = _interpret_as_from(table)
+        self.table = coercions.expect(roles.FromClauseRole, table)
         self.parameters, self._has_multi_parameters = self._process_colparams(
             values
         )
@@ -604,13 +607,16 @@ class Insert(ValuesBase):
             )
 
         self.parameters, self._has_multi_parameters = self._process_colparams(
-            {_column_as_key(n): Null() for n in names}
+            {
+                coercions.expect(roles.DMLColumnRole, n, as_key=True): Null()
+                for n in names
+            }
         )
 
         self.select_names = names
         self.inline = True
         self.include_insert_from_select_defaults = include_defaults
-        self.select = _interpret_as_select(select)
+        self.select = coercions.expect(roles.DMLSelectRole, select)
 
     def _copy_internals(self, clone=_clone, **kw):
         # TODO: coverage
@@ -678,7 +684,7 @@ class Update(ValuesBase):
             users.update().values(name='ed').where(
                     users.c.name==select([addresses.c.email_address]).\
                                 where(addresses.c.user_id==users.c.id).\
-                                as_scalar()
+                                scalar_subquery()
                     )
 
         :param values:
@@ -744,7 +750,7 @@ class Update(ValuesBase):
             users.update().values(
                     name=select([addresses.c.email_address]).\
                             where(addresses.c.user_id==users.c.id).\
-                            as_scalar()
+                            scalar_subquery()
                 )
 
         .. seealso::
@@ -759,7 +765,9 @@ class Update(ValuesBase):
         self._bind = bind
         self._returning = returning
         if whereclause is not None:
-            self._whereclause = _literal_as_text(whereclause)
+            self._whereclause = coercions.expect(
+                roles.WhereHavingRole, whereclause
+            )
         else:
             self._whereclause = None
         self.inline = inline
@@ -785,10 +793,13 @@ class Update(ValuesBase):
         """
         if self._whereclause is not None:
             self._whereclause = and_(
-                self._whereclause, _literal_as_text(whereclause)
+                self._whereclause,
+                coercions.expect(roles.WhereHavingRole, whereclause),
             )
         else:
-            self._whereclause = _literal_as_text(whereclause)
+            self._whereclause = coercions.expect(
+                roles.WhereHavingRole, whereclause
+            )
 
     @property
     def _extra_froms(self):
@@ -846,7 +857,7 @@ class Delete(UpdateBase):
             users.delete().where(
                     users.c.name==select([addresses.c.email_address]).\
                                 where(addresses.c.user_id==users.c.id).\
-                                as_scalar()
+                                scalar_subquery()
                     )
 
          .. versionchanged:: 1.2.0
@@ -858,14 +869,16 @@ class Delete(UpdateBase):
 
         """
         self._bind = bind
-        self.table = _interpret_as_from(table)
+        self.table = coercions.expect(roles.FromClauseRole, table)
         self._returning = returning
 
         if prefixes:
             self._setup_prefixes(prefixes)
 
         if whereclause is not None:
-            self._whereclause = _literal_as_text(whereclause)
+            self._whereclause = coercions.expect(
+                roles.WhereHavingRole, whereclause
+            )
         else:
             self._whereclause = None
 
@@ -883,10 +896,13 @@ class Delete(UpdateBase):
 
         if self._whereclause is not None:
             self._whereclause = and_(
-                self._whereclause, _literal_as_text(whereclause)
+                self._whereclause,
+                coercions.expect(roles.WhereHavingRole, whereclause),
             )
         else:
-            self._whereclause = _literal_as_text(whereclause)
+            self._whereclause = coercions.expect(
+                roles.WhereHavingRole, whereclause
+            )
 
     @property
     def _extra_froms(self):
index e634e5a367d53d12cec7c2c116ed24384e1c8bf6..a333303ec25815f323a203a6cd3c833f7af19a33 100644 (file)
 from __future__ import unicode_literals
 
 import itertools
-import numbers
 import operator
 import re
 
 from . import clause_compare
+from . import coercions
 from . import operators
+from . import roles
 from . import type_api
 from .annotation import Annotated
 from .base import _generative
@@ -26,6 +27,7 @@ from .base import Executable
 from .base import Immutable
 from .base import NO_ARG
 from .base import PARSE_AUTOCOMMIT
+from .coercions import _document_text_coercion
 from .visitors import cloned_traverse
 from .visitors import traverse
 from .visitors import Visitable
@@ -38,20 +40,6 @@ def _clone(element, **kw):
     return element._clone()
 
 
-def _document_text_coercion(paramname, meth_rst, param_rst):
-    return util.add_parameter_text(
-        paramname,
-        (
-            ".. warning:: "
-            "The %s argument to %s can be passed as a Python string argument, "
-            "which will be treated "
-            "as **trusted SQL text** and rendered as given.  **DO NOT PASS "
-            "UNTRUSTED INPUT TO THIS PARAMETER**."
-        )
-        % (param_rst, meth_rst),
-    )
-
-
 def collate(expression, collation):
     """Return the clause ``expression COLLATE collation``.
 
@@ -71,7 +59,7 @@ def collate(expression, collation):
 
     """
 
-    expr = _literal_as_binds(expression)
+    expr = coercions.expect(roles.ExpressionElementRole, expression)
     return BinaryExpression(
         expr, CollationClause(collation), operators.collate, type_=expr.type
     )
@@ -127,7 +115,7 @@ def between(expr, lower_bound, upper_bound, symmetric=False):
         :meth:`.ColumnElement.between`
 
     """
-    expr = _literal_as_binds(expr)
+    expr = coercions.expect(roles.ExpressionElementRole, expr)
     return expr.between(lower_bound, upper_bound, symmetric=symmetric)
 
 
@@ -172,11 +160,11 @@ def not_(clause):
     same result.
 
     """
-    return operators.inv(_literal_as_binds(clause))
+    return operators.inv(coercions.expect(roles.ExpressionElementRole, clause))
 
 
 @inspection._self_inspects
-class ClauseElement(Visitable):
+class ClauseElement(roles.SQLRole, Visitable):
     """Base class for elements of a programmatically constructed SQL
     expression.
 
@@ -188,13 +176,20 @@ class ClauseElement(Visitable):
     supports_execution = False
     _from_objects = []
     bind = None
+    description = None
     _is_clone_of = None
-    is_selectable = False
+
     is_clause_element = True
+    is_selectable = False
 
-    description = None
-    _order_by_label_element = None
+    _is_textual = False
+    _is_from_clause = False
+    _is_returns_rows = False
+    _is_text_clause = False
     _is_from_container = False
+    _is_select_statement = False
+
+    _order_by_label_element = None
 
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
@@ -238,7 +233,7 @@ class ClauseElement(Visitable):
 
 
         """
-        raise NotImplementedError(self.__class__)
+        raise NotImplementedError()
 
     @property
     def _constructor(self):
@@ -394,6 +389,7 @@ class ClauseElement(Visitable):
         return []
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         """Apply a 'grouping' to this :class:`.ClauseElement`.
 
         This method is overridden by subclasses to return a
@@ -553,7 +549,20 @@ class ClauseElement(Visitable):
             )
 
 
-class ColumnElement(operators.ColumnOperators, ClauseElement):
+class ColumnElement(
+    roles.ColumnArgumentOrKeyRole,
+    roles.StatementOptionRole,
+    roles.WhereHavingRole,
+    roles.BinaryElementRole,
+    roles.OrderByRole,
+    roles.ColumnsClauseRole,
+    roles.LimitOffsetRole,
+    roles.DMLColumnRole,
+    roles.DDLConstraintColumnRole,
+    roles.DDLExpressionRole,
+    operators.ColumnOperators,
+    ClauseElement,
+):
     """Represent a column-oriented SQL expression suitable for usage in the
     "columns" clause, WHERE clause etc. of a statement.
 
@@ -586,17 +595,13 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
           :class:`.TypeEngine` objects) are applied to the value.
 
         * any special object value, typically ORM-level constructs, which
-          feature a method called ``__clause_element__()``.  The Core
+          feature an accessor called ``__clause_element__()``.  The Core
           expression system looks for this method when an object of otherwise
           unknown type is passed to a function that is looking to coerce the
-          argument into a :class:`.ColumnElement` expression.  The
-          ``__clause_element__()`` method, if present, should return a
-          :class:`.ColumnElement` instance.  The primary use of
-          ``__clause_element__()`` within SQLAlchemy is that of class-bound
-          attributes on ORM-mapped classes; a ``User`` class which contains a
-          mapped attribute named ``.name`` will have a method
-          ``User.name.__clause_element__()`` which when invoked returns the
-          :class:`.Column` called ``name`` associated with the mapped table.
+          argument into a :class:`.ColumnElement` and sometimes a
+          :class:`.SelectBase` expression.   It is used within the ORM to
+          convert from ORM-specific objects like mapped classes and
+          mapped attributes into Core expression objects.
 
         * The Python ``None`` value is typically interpreted as ``NULL``,
           which in SQLAlchemy Core produces an instance of :func:`.null`.
@@ -702,6 +707,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
     _alt_names = ()
 
     def self_group(self, against=None):
+        # type: (Module, Module, Optional[Any]) -> ClauseEleent
         if (
             against in (operators.and_, operators.or_, operators._asbool)
             and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
@@ -826,7 +832,9 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
         else:
             key = name
         co = ColumnClause(
-            _as_truncated(name) if name_is_truncatable else name,
+            coercions.expect(roles.TruncatedLabelRole, name)
+            if name_is_truncatable
+            else name,
             type_=getattr(self, "type", None),
             _selectable=selectable,
         )
@@ -878,7 +886,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
         )
 
 
-class BindParameter(ColumnElement):
+class BindParameter(roles.InElementRole, ColumnElement):
     r"""Represent a "bound expression".
 
     :class:`.BindParameter` is invoked explicitly using the
@@ -1235,7 +1243,8 @@ class BindParameter(ColumnElement):
                     "bindparams collection argument required for _cache_key "
                     "implementation.  Bound parameter cache keys are not safe "
                     "to use without accommodating for the value or callable "
-                    "within the parameter itself.")
+                    "within the parameter itself."
+                )
         else:
             bindparams.append(self)
         return (BindParameter, self.type._cache_key, self._orig_key)
@@ -1282,7 +1291,20 @@ class TypeClause(ClauseElement):
         return (TypeClause, self.type._cache_key)
 
 
-class TextClause(Executable, ClauseElement):
+class TextClause(
+    roles.DDLConstraintColumnRole,
+    roles.DDLExpressionRole,
+    roles.StatementOptionRole,
+    roles.WhereHavingRole,
+    roles.OrderByRole,
+    roles.FromClauseRole,
+    roles.SelectStatementRole,
+    roles.CoerceTextStatementRole,
+    roles.BinaryElementRole,
+    roles.InElementRole,
+    Executable,
+    ClauseElement,
+):
     """Represent a literal SQL text fragment.
 
     E.g.::
@@ -1304,6 +1326,10 @@ class TextClause(Executable, ClauseElement):
 
     __visit_name__ = "textclause"
 
+    _is_text_clause = True
+
+    _is_textual = True
+
     _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
     _execution_options = Executable._execution_options.union(
         {"autocommit": PARSE_AUTOCOMMIT}
@@ -1318,20 +1344,16 @@ class TextClause(Executable, ClauseElement):
     def _select_iterable(self):
         return (self,)
 
-    @property
-    def selectable(self):
-        # allows text() to be considered by
-        # _interpret_as_from
-        return self
-
-    _hide_froms = []
-
     # help in those cases where text() is
     # interpreted in a column expression situation
     key = _label = _resolve_label = None
 
     _allow_label_resolve = False
 
+    @property
+    def _hide_froms(self):
+        return []
+
     def __init__(self, text, bind=None):
         self._bind = bind
         self._bindparams = {}
@@ -1670,7 +1692,6 @@ class TextClause(Executable, ClauseElement):
 
 
         """
-
         positional_input_cols = [
             ColumnClause(col.key, types.pop(col.key))
             if col.key in types
@@ -1696,6 +1717,7 @@ class TextClause(Executable, ClauseElement):
         return self.type.comparator_factory(self)
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> Union[Grouping, TextClause]
         if against is operators.in_op:
             return Grouping(self)
         else:
@@ -1715,7 +1737,7 @@ class TextClause(Executable, ClauseElement):
         )
 
 
-class Null(ColumnElement):
+class Null(roles.ConstExprRole, ColumnElement):
     """Represent the NULL keyword in a SQL statement.
 
     :class:`.Null` is accessed as a constant via the
@@ -1739,7 +1761,7 @@ class Null(ColumnElement):
         return (Null,)
 
 
-class False_(ColumnElement):
+class False_(roles.ConstExprRole, ColumnElement):
     """Represent the ``false`` keyword, or equivalent, in a SQL statement.
 
     :class:`.False_` is accessed as a constant via the
@@ -1798,7 +1820,7 @@ class False_(ColumnElement):
         return (False_,)
 
 
-class True_(ColumnElement):
+class True_(roles.ConstExprRole, ColumnElement):
     """Represent the ``true`` keyword, or equivalent, in a SQL statement.
 
     :class:`.True_` is accessed as a constant via the
@@ -1864,7 +1886,12 @@ class True_(ColumnElement):
         return (True_,)
 
 
-class ClauseList(ClauseElement):
+class ClauseList(
+    roles.InElementRole,
+    roles.OrderByRole,
+    roles.ColumnsClauseRole,
+    ClauseElement,
+):
     """Describe a list of clauses, separated by an operator.
 
     By default, is comma-separated, such as a column listing.
@@ -1877,16 +1904,22 @@ class ClauseList(ClauseElement):
         self.operator = kwargs.pop("operator", operators.comma_op)
         self.group = kwargs.pop("group", True)
         self.group_contents = kwargs.pop("group_contents", True)
-        text_converter = kwargs.pop(
-            "_literal_as_text", _expression_literal_as_text
+
+        self._text_converter_role = text_converter_role = kwargs.pop(
+            "_literal_as_text_role", roles.WhereHavingRole
         )
         if self.group_contents:
             self.clauses = [
-                text_converter(clause).self_group(against=self.operator)
+                coercions.expect(text_converter_role, clause).self_group(
+                    against=self.operator
+                )
                 for clause in clauses
             ]
         else:
-            self.clauses = [text_converter(clause) for clause in clauses]
+            self.clauses = [
+                coercions.expect(text_converter_role, clause)
+                for clause in clauses
+            ]
         self._is_implicitly_boolean = operators.is_boolean(self.operator)
 
     def __iter__(self):
@@ -1902,10 +1935,14 @@ class ClauseList(ClauseElement):
     def append(self, clause):
         if self.group_contents:
             self.clauses.append(
-                _literal_as_text(clause).self_group(against=self.operator)
+                coercions.expect(self._text_converter_role, clause).self_group(
+                    against=self.operator
+                )
             )
         else:
-            self.clauses.append(_literal_as_text(clause))
+            self.clauses.append(
+                coercions.expect(self._text_converter_role, clause)
+            )
 
     def _copy_internals(self, clone=_clone, **kw):
         self.clauses = [clone(clause, **kw) for clause in self.clauses]
@@ -1923,6 +1960,7 @@ class ClauseList(ClauseElement):
         return list(itertools.chain(*[c._from_objects for c in self.clauses]))
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         if self.group and operators.is_precedent(self.operator, against):
             return Grouping(self)
         else:
@@ -1947,7 +1985,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
         convert_clauses = []
 
         clauses = [
-            _expression_literal_as_text(clause)
+            coercions.expect(roles.WhereHavingRole, clause)
             for clause in util.coerce_generator_arg(clauses)
         ]
         for clause in clauses:
@@ -2055,6 +2093,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
         return (self,)
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         if not self.clauses:
             return self
         else:
@@ -2092,7 +2131,9 @@ class Tuple(ClauseList, ColumnElement):
 
         """
 
-        clauses = [_literal_as_binds(c) for c in clauses]
+        clauses = [
+            coercions.expect(roles.ExpressionElementRole, c) for c in clauses
+        ]
         self._type_tuple = [arg.type for arg in clauses]
         self.type = kw.pop(
             "type_",
@@ -2283,12 +2324,20 @@ class Case(ColumnElement):
 
         if value is not None:
             whenlist = [
-                (_literal_as_binds(c).self_group(), _literal_as_binds(r))
+                (
+                    coercions.expect(
+                        roles.ExpressionElementRole, c
+                    ).self_group(),
+                    coercions.expect(roles.ExpressionElementRole, r),
+                )
                 for (c, r) in whens
             ]
         else:
             whenlist = [
-                (_no_literals(c).self_group(), _literal_as_binds(r))
+                (
+                    coercions.expect(roles.ColumnArgumentRole, c).self_group(),
+                    coercions.expect(roles.ExpressionElementRole, r),
+                )
                 for (c, r) in whens
             ]
 
@@ -2300,12 +2349,12 @@ class Case(ColumnElement):
         if value is None:
             self.value = None
         else:
-            self.value = _literal_as_binds(value)
+            self.value = coercions.expect(roles.ExpressionElementRole, value)
 
         self.type = type_
         self.whens = whenlist
         if else_ is not None:
-            self.else_ = _literal_as_binds(else_)
+            self.else_ = coercions.expect(roles.ExpressionElementRole, else_)
         else:
             self.else_ = None
 
@@ -2455,7 +2504,9 @@ class Cast(ColumnElement):
 
         """
         self.type = type_api.to_instance(type_)
-        self.clause = _literal_as_binds(expression, type_=self.type)
+        self.clause = coercions.expect(
+            roles.ExpressionElementRole, expression, type_=self.type
+        )
         self.typeclause = TypeClause(self.type)
 
     def _copy_internals(self, clone=_clone, **kw):
@@ -2557,7 +2608,9 @@ class TypeCoerce(ColumnElement):
 
         """
         self.type = type_api.to_instance(type_)
-        self.clause = _literal_as_binds(expression, type_=self.type)
+        self.clause = coercions.expect(
+            roles.ExpressionElementRole, expression, type_=self.type
+        )
 
     def _copy_internals(self, clone=_clone, **kw):
         self.clause = clone(self.clause, **kw)
@@ -2598,7 +2651,7 @@ class Extract(ColumnElement):
         """
         self.type = type_api.INTEGERTYPE
         self.field = field
-        self.expr = _literal_as_binds(expr, None)
+        self.expr = coercions.expect(roles.ExpressionElementRole, expr)
 
     def _copy_internals(self, clone=_clone, **kw):
         self.expr = clone(self.expr, **kw)
@@ -2733,7 +2786,7 @@ class UnaryExpression(ColumnElement):
 
         """
         return UnaryExpression(
-            _literal_as_label_reference(column),
+            coercions.expect(roles.ByOfRole, column),
             modifier=operators.nullsfirst_op,
             wraps_column_expression=False,
         )
@@ -2776,7 +2829,7 @@ class UnaryExpression(ColumnElement):
 
         """
         return UnaryExpression(
-            _literal_as_label_reference(column),
+            coercions.expect(roles.ByOfRole, column),
             modifier=operators.nullslast_op,
             wraps_column_expression=False,
         )
@@ -2817,7 +2870,7 @@ class UnaryExpression(ColumnElement):
 
         """
         return UnaryExpression(
-            _literal_as_label_reference(column),
+            coercions.expect(roles.ByOfRole, column),
             modifier=operators.desc_op,
             wraps_column_expression=False,
         )
@@ -2857,7 +2910,7 @@ class UnaryExpression(ColumnElement):
 
         """
         return UnaryExpression(
-            _literal_as_label_reference(column),
+            coercions.expect(roles.ByOfRole, column),
             modifier=operators.asc_op,
             wraps_column_expression=False,
         )
@@ -2898,7 +2951,7 @@ class UnaryExpression(ColumnElement):
             :data:`.func`
 
         """
-        expr = _literal_as_binds(expr)
+        expr = coercions.expect(roles.ExpressionElementRole, expr)
         return UnaryExpression(
             expr,
             operator=operators.distinct_op,
@@ -2953,6 +3006,7 @@ class UnaryExpression(ColumnElement):
             return ClauseElement._negate(self)
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         if self.operator and operators.is_precedent(self.operator, against):
             return Grouping(self)
         else:
@@ -2990,10 +3044,8 @@ class CollectionAggregate(UnaryExpression):
 
         """
 
-        expr = _literal_as_binds(expr)
+        expr = coercions.expect(roles.ExpressionElementRole, expr)
 
-        if expr.is_selectable and hasattr(expr, "as_scalar"):
-            expr = expr.as_scalar()
         expr = expr.self_group()
         return CollectionAggregate(
             expr,
@@ -3023,9 +3075,7 @@ class CollectionAggregate(UnaryExpression):
 
         """
 
-        expr = _literal_as_binds(expr)
-        if expr.is_selectable and hasattr(expr, "as_scalar"):
-            expr = expr.as_scalar()
+        expr = coercions.expect(roles.ExpressionElementRole, expr)
         expr = expr.self_group()
         return CollectionAggregate(
             expr,
@@ -3064,6 +3114,7 @@ class AsBoolean(UnaryExpression):
         self._is_implicitly_boolean = element._is_implicitly_boolean
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         return self
 
     def _cache_key(self, **kw):
@@ -3155,6 +3206,8 @@ class BinaryExpression(ColumnElement):
         )
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
+
         if operators.is_precedent(self.operator, against):
             return Grouping(self)
         else:
@@ -3191,6 +3244,7 @@ class Slice(ColumnElement):
         self.type = type_api.NULLTYPE
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         assert against is operator.getitem
         return self
 
@@ -3215,6 +3269,7 @@ class Grouping(ColumnElement):
         self.type = getattr(element, "type", type_api.NULLTYPE)
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         return self
 
     @util.memoized_property
@@ -3363,13 +3418,12 @@ class Over(ColumnElement):
         self.element = element
         if order_by is not None:
             self.order_by = ClauseList(
-                *util.to_list(order_by),
-                _literal_as_text=_literal_as_label_reference
+                *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
             )
         if partition_by is not None:
             self.partition_by = ClauseList(
                 *util.to_list(partition_by),
-                _literal_as_text=_literal_as_label_reference
+                _literal_as_text_role=roles.ByOfRole
             )
 
         if range_:
@@ -3534,8 +3588,7 @@ class WithinGroup(ColumnElement):
         self.element = element
         if order_by is not None:
             self.order_by = ClauseList(
-                *util.to_list(order_by),
-                _literal_as_text=_literal_as_label_reference
+                *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
             )
 
     def over(self, partition_by=None, order_by=None, range_=None, rows=None):
@@ -3658,7 +3711,7 @@ class FunctionFilter(ColumnElement):
         """
 
         for criterion in list(criterion):
-            criterion = _expression_literal_as_text(criterion)
+            criterion = coercions.expect(roles.WhereHavingRole, criterion)
 
             if self.criterion is not None:
                 self.criterion = self.criterion & criterion
@@ -3727,7 +3780,7 @@ class FunctionFilter(ColumnElement):
         )
 
 
-class Label(ColumnElement):
+class Label(roles.LabeledColumnExprRole, ColumnElement):
     """Represents a column label (AS).
 
     Represent a label, as typically applied to any column-level
@@ -3801,6 +3854,7 @@ class Label(ColumnElement):
         return self._element.self_group(against=operators.as_)
 
     def self_group(self, against=None):
+        # type: (Optional[Any]) -> ClauseElement
         return self._apply_to_inner(self._element.self_group, against=against)
 
     def _negate(self):
@@ -3849,7 +3903,7 @@ class Label(ColumnElement):
         return e
 
 
-class ColumnClause(Immutable, ColumnElement):
+class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement):
     """Represents a column expression from any textual string.
 
     The :class:`.ColumnClause`, a lightweight analogue to the
@@ -3985,14 +4039,14 @@ class ColumnClause(Immutable, ColumnElement):
         if (
             self.is_literal
             or self.table is None
-            or self.table._textual
+            or self.table._is_textual
             or not hasattr(other, "proxy_set")
             or (
                 isinstance(other, ColumnClause)
                 and (
                     other.is_literal
                     or other.table is None
-                    or other.table._textual
+                    or other.table._is_textual
                 )
             )
         ):
@@ -4083,7 +4137,7 @@ class ColumnClause(Immutable, ColumnElement):
                     counter += 1
                 label = _label
 
-            return _as_truncated(label)
+            return coercions.expect(roles.TruncatedLabelRole, label)
 
         else:
             return name
@@ -4110,7 +4164,7 @@ class ColumnClause(Immutable, ColumnElement):
         # otherwise its considered to be a label
         is_literal = self.is_literal and (name is None or name == self.name)
         c = self._constructor(
-            _as_truncated(name or self.name)
+            coercions.expect(roles.TruncatedLabelRole, name or self.name)
             if name_is_truncatable
             else (name or self.name),
             type_=self.type,
@@ -4250,6 +4304,108 @@ class quoted_name(util.MemoizedSlots, util.text_type):
         return "'%s'" % backslashed
 
 
+def _expand_cloned(elements):
+    """expand the given set of ClauseElements to be the set of all 'cloned'
+    predecessors.
+
+    """
+    return itertools.chain(*[x._cloned_set for x in elements])
+
+
+def _select_iterables(elements):
+    """expand tables into individual columns in the
+    given list of column expressions.
+
+    """
+    return itertools.chain(*[c._select_iterable for c in elements])
+
+
+def _cloned_intersection(a, b):
+    """return the intersection of sets a and b, counting
+    any overlap between 'cloned' predecessors.
+
+    The returned set is in terms of the entities present within 'a'.
+
+    """
+    all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+    return set(
+        elem for elem in a if all_overlap.intersection(elem._cloned_set)
+    )
+
+
+def _cloned_difference(a, b):
+    all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+    return set(
+        elem for elem in a if not all_overlap.intersection(elem._cloned_set)
+    )
+
+
+def _find_columns(clause):
+    """locate Column objects within the given expression."""
+
+    cols = util.column_set()
+    traverse(clause, {}, {"column": cols.add})
+    return cols
+
+
+def _type_from_args(args):
+    for a in args:
+        if not a.type._isnull:
+            return a.type
+    else:
+        return type_api.NULLTYPE
+
+
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+    c = fromclause.corresponding_column(
+        column, require_embedded=require_embedded
+    )
+    if c is None:
+        raise exc.InvalidRequestError(
+            "Given column '%s', attached to table '%s', "
+            "failed to locate a corresponding column from table '%s'"
+            % (column, getattr(column, "table", None), fromclause.description)
+        )
+    return c
+
+
+class AnnotatedColumnElement(Annotated):
+    def __init__(self, element, values):
+        Annotated.__init__(self, element, values)
+        ColumnElement.comparator._reset(self)
+        for attr in ("name", "key", "table"):
+            if self.__dict__.get(attr, False) is None:
+                self.__dict__.pop(attr)
+
+    def _with_annotations(self, values):
+        clone = super(AnnotatedColumnElement, self)._with_annotations(values)
+        ColumnElement.comparator._reset(clone)
+        return clone
+
+    @util.memoized_property
+    def name(self):
+        """pull 'name' from parent, if not present"""
+        return self._Annotated__element.name
+
+    @util.memoized_property
+    def table(self):
+        """pull 'table' from parent, if not present"""
+        return self._Annotated__element.table
+
+    @util.memoized_property
+    def key(self):
+        """pull 'key' from parent, if not present"""
+        return self._Annotated__element.key
+
+    @util.memoized_property
+    def info(self):
+        return self._Annotated__element.info
+
+    @util.memoized_property
+    def anon_label(self):
+        return self._Annotated__element.anon_label
+
+
 class _truncated_label(quoted_name):
     """A unicode subclass used to identify symbolic "
     "names that may require truncation."""
@@ -4378,349 +4534,3 @@ class _anonymous_label(_truncated_label):
         else:
             # else skip the constructor call
             return self % map_
-
-
-def _as_truncated(value):
-    """coerce the given value to :class:`._truncated_label`.
-
-    Existing :class:`._truncated_label` and
-    :class:`._anonymous_label` objects are passed
-    unchanged.
-    """
-
-    if isinstance(value, _truncated_label):
-        return value
-    else:
-        return _truncated_label(value)
-
-
-def _string_or_unprintable(element):
-    if isinstance(element, util.string_types):
-        return element
-    else:
-        try:
-            return str(element)
-        except Exception:
-            return "unprintable element %r" % element
-
-
-def _expand_cloned(elements):
-    """expand the given set of ClauseElements to be the set of all 'cloned'
-    predecessors.
-
-    """
-    return itertools.chain(*[x._cloned_set for x in elements])
-
-
-def _select_iterables(elements):
-    """expand tables into individual columns in the
-    given list of column expressions.
-
-    """
-    return itertools.chain(*[c._select_iterable for c in elements])
-
-
-def _cloned_intersection(a, b):
-    """return the intersection of sets a and b, counting
-    any overlap between 'cloned' predecessors.
-
-    The returned set is in terms of the entities present within 'a'.
-
-    """
-    all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
-    return set(
-        elem for elem in a if all_overlap.intersection(elem._cloned_set)
-    )
-
-
-def _cloned_difference(a, b):
-    all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
-    return set(
-        elem for elem in a if not all_overlap.intersection(elem._cloned_set)
-    )
-
-
-@util.dependencies("sqlalchemy.sql.functions")
-def _labeled(functions, element):
-    if not hasattr(element, "name") or isinstance(
-        element, functions.FunctionElement
-    ):
-        return element.label(None)
-    else:
-        return element
-
-
-def _is_column(col):
-    """True if ``col`` is an instance of :class:`.ColumnElement`."""
-
-    return isinstance(col, ColumnElement)
-
-
-def _find_columns(clause):
-    """locate Column objects within the given expression."""
-
-    cols = util.column_set()
-    traverse(clause, {}, {"column": cols.add})
-    return cols
-
-
-# there is some inconsistency here between the usage of
-# inspect() vs. checking for Visitable and __clause_element__.
-# Ideally all functions here would derive from inspect(),
-# however the inspect() versions add significant callcount
-# overhead for critical functions like _interpret_as_column_or_from().
-# Generally, the column-based functions are more performance critical
-# and are fine just checking for __clause_element__().  It is only
-# _interpret_as_from() where we'd like to be able to receive ORM entities
-# that have no defined namespace, hence inspect() is needed there.
-
-
-def _column_as_key(element):
-    if isinstance(element, util.string_types):
-        return element
-    if hasattr(element, "__clause_element__"):
-        element = element.__clause_element__()
-    try:
-        return element.key
-    except AttributeError:
-        return None
-
-
-def _clause_element_as_expr(element):
-    if hasattr(element, "__clause_element__"):
-        return element.__clause_element__()
-    else:
-        return element
-
-
-def _literal_as_label_reference(element):
-    if isinstance(element, util.string_types):
-        return _textual_label_reference(element)
-
-    elif hasattr(element, "__clause_element__"):
-        element = element.__clause_element__()
-
-    return _literal_as_text(element)
-
-
-def _literal_and_labels_as_label_reference(element):
-    if isinstance(element, util.string_types):
-        return _textual_label_reference(element)
-
-    elif hasattr(element, "__clause_element__"):
-        element = element.__clause_element__()
-
-    if (
-        isinstance(element, ColumnElement)
-        and element._order_by_label_element is not None
-    ):
-        return _label_reference(element)
-    else:
-        return _literal_as_text(element)
-
-
-def _expression_literal_as_text(element):
-    return _literal_as_text(element)
-
-
-def _literal_as(element, text_fallback):
-    if isinstance(element, Visitable):
-        return element
-    elif hasattr(element, "__clause_element__"):
-        return element.__clause_element__()
-    elif isinstance(element, util.string_types):
-        return text_fallback(element)
-    elif isinstance(element, (util.NoneType, bool)):
-        return _const_expr(element)
-    else:
-        raise exc.ArgumentError(
-            "SQL expression object expected, got object of type %r "
-            "instead" % type(element)
-        )
-
-
-def _literal_as_text(element, allow_coercion_to_text=False):
-    if allow_coercion_to_text:
-        return _literal_as(element, TextClause)
-    else:
-        return _literal_as(element, _no_text_coercion)
-
-
-def _literal_as_column(element):
-    return _literal_as(element, ColumnClause)
-
-
-def _no_column_coercion(element):
-    element = str(element)
-    guess_is_literal = not _guess_straight_column.match(element)
-    raise exc.ArgumentError(
-        "Textual column expression %(column)r should be "
-        "explicitly declared with text(%(column)r), "
-        "or use %(literal_column)s(%(column)r) "
-        "for more specificity"
-        % {
-            "column": util.ellipses_string(element),
-            "literal_column": "literal_column"
-            if guess_is_literal
-            else "column",
-        }
-    )
-
-
-def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None):
-    raise exc_cls(
-        "%(extra)sTextual SQL expression %(expr)r should be "
-        "explicitly declared as text(%(expr)r)"
-        % {
-            "expr": util.ellipses_string(element),
-            "extra": "%s " % extra if extra else "",
-        }
-    )
-
-
-def _no_literals(element):
-    if hasattr(element, "__clause_element__"):
-        return element.__clause_element__()
-    elif not isinstance(element, Visitable):
-        raise exc.ArgumentError(
-            "Ambiguous literal: %r.  Use the 'text()' "
-            "function to indicate a SQL expression "
-            "literal, or 'literal()' to indicate a "
-            "bound value." % (element,)
-        )
-    else:
-        return element
-
-
-def _is_literal(element):
-    return not isinstance(element, Visitable) and not hasattr(
-        element, "__clause_element__"
-    )
-
-
-def _only_column_elements_or_none(element, name):
-    if element is None:
-        return None
-    else:
-        return _only_column_elements(element, name)
-
-
-def _only_column_elements(element, name):
-    if hasattr(element, "__clause_element__"):
-        element = element.__clause_element__()
-    if not isinstance(element, ColumnElement):
-        raise exc.ArgumentError(
-            "Column-based expression object expected for argument "
-            "'%s'; got: '%s', type %s" % (name, element, type(element))
-        )
-    return element
-
-
-def _literal_as_binds(element, name=None, type_=None):
-    if hasattr(element, "__clause_element__"):
-        return element.__clause_element__()
-    elif not isinstance(element, Visitable):
-        if element is None:
-            return Null()
-        else:
-            return BindParameter(name, element, type_=type_, unique=True)
-    else:
-        return element
-
-
-_guess_straight_column = re.compile(r"^\w\S*$", re.I)
-
-
-def _interpret_as_column_or_from(element):
-    if isinstance(element, Visitable):
-        return element
-    elif hasattr(element, "__clause_element__"):
-        return element.__clause_element__()
-
-    insp = inspection.inspect(element, raiseerr=False)
-    if insp is None:
-        if isinstance(element, (util.NoneType, bool)):
-            return _const_expr(element)
-    elif hasattr(insp, "selectable"):
-        return insp.selectable
-
-    # be forgiving as this is an extremely common
-    # and known expression
-    if element == "*":
-        guess_is_literal = True
-    elif isinstance(element, (numbers.Number)):
-        return ColumnClause(str(element), is_literal=True)
-    else:
-        _no_column_coercion(element)
-    return ColumnClause(element, is_literal=guess_is_literal)
-
-
-def _const_expr(element):
-    if isinstance(element, (Null, False_, True_)):
-        return element
-    elif element is None:
-        return Null()
-    elif element is False:
-        return False_()
-    elif element is True:
-        return True_()
-    else:
-        raise exc.ArgumentError("Expected None, False, or True")
-
-
-def _type_from_args(args):
-    for a in args:
-        if not a.type._isnull:
-            return a.type
-    else:
-        return type_api.NULLTYPE
-
-
-def _corresponding_column_or_error(fromclause, column, require_embedded=False):
-    c = fromclause.corresponding_column(
-        column, require_embedded=require_embedded
-    )
-    if c is None:
-        raise exc.InvalidRequestError(
-            "Given column '%s', attached to table '%s', "
-            "failed to locate a corresponding column from table '%s'"
-            % (column, getattr(column, "table", None), fromclause.description)
-        )
-    return c
-
-
-class AnnotatedColumnElement(Annotated):
-    def __init__(self, element, values):
-        Annotated.__init__(self, element, values)
-        ColumnElement.comparator._reset(self)
-        for attr in ("name", "key", "table"):
-            if self.__dict__.get(attr, False) is None:
-                self.__dict__.pop(attr)
-
-    def _with_annotations(self, values):
-        clone = super(AnnotatedColumnElement, self)._with_annotations(values)
-        ColumnElement.comparator._reset(clone)
-        return clone
-
-    @util.memoized_property
-    def name(self):
-        """pull 'name' from parent, if not present"""
-        return self._Annotated__element.name
-
-    @util.memoized_property
-    def table(self):
-        """pull 'table' from parent, if not present"""
-        return self._Annotated__element.table
-
-    @util.memoized_property
-    def key(self):
-        """pull 'key' from parent, if not present"""
-        return self._Annotated__element.key
-
-    @util.memoized_property
-    def info(self):
-        return self._Annotated__element.info
-
-    @util.memoized_property
-    def anon_label(self):
-        return self._Annotated__element.anon_label
index f381879ce17d32714ba672b02e1b4c134bfbfa74..b04355cf5ddf9f3692ecacc660a19dcc2a029999 100644 (file)
@@ -67,7 +67,6 @@ __all__ = [
     "outerjoin",
     "over",
     "select",
-    "subquery",
     "table",
     "text",
     "tuple_",
@@ -92,22 +91,7 @@ from .dml import Insert  # noqa
 from .dml import Update  # noqa
 from .dml import UpdateBase  # noqa
 from .dml import ValuesBase  # noqa
-from .elements import _clause_element_as_expr  # noqa
-from .elements import _clone  # noqa
-from .elements import _cloned_difference  # noqa
-from .elements import _cloned_intersection  # noqa
-from .elements import _column_as_key  # noqa
-from .elements import _corresponding_column_or_error  # noqa
-from .elements import _expression_literal_as_text  # noqa
-from .elements import _is_column  # noqa
-from .elements import _labeled  # noqa
-from .elements import _literal_as_binds  # noqa
-from .elements import _literal_as_column  # noqa
-from .elements import _literal_as_label_reference  # noqa
-from .elements import _literal_as_text  # noqa
-from .elements import _only_column_elements  # noqa
 from .elements import _select_iterables  # noqa
-from .elements import _string_or_unprintable  # noqa
 from .elements import _truncated_label  # noqa
 from .elements import between  # noqa
 from .elements import BinaryExpression  # noqa
@@ -147,7 +131,6 @@ from .functions import func  # noqa
 from .functions import Function  # noqa
 from .functions import FunctionElement  # noqa
 from .functions import modifier  # noqa
-from .selectable import _interpret_as_from  # noqa
 from .selectable import Alias  # noqa
 from .selectable import CompoundSelect  # noqa
 from .selectable import CTE  # noqa
@@ -160,6 +143,7 @@ from .selectable import HasPrefixes  # noqa
 from .selectable import HasSuffixes  # noqa
 from .selectable import Join  # noqa
 from .selectable import Lateral  # noqa
+from .selectable import ReturnsRows  # noqa
 from .selectable import ScalarSelect  # noqa
 from .selectable import Select  # noqa
 from .selectable import Selectable  # noqa
@@ -171,7 +155,6 @@ from .selectable import TextAsFrom  # noqa
 from .visitors import Visitable  # noqa
 from ..util.langhelpers import public_factory  # noqa
 
-
 # factory functions - these pull class-bound constructors and classmethods
 # from SQL elements and selectables into public functions.  This allows
 # the functions to be available in the sqlalchemy.sql.* namespace and
index d0aa239881222d90d9dcb5ce08e74649b6e5de24..17378999860f6bc02007150a8ba1480fcbd235f8 100644 (file)
@@ -9,14 +9,15 @@
 
 """
 from . import annotation
+from . import coercions
 from . import operators
+from . import roles
 from . import schema
 from . import sqltypes
 from . import util as sqlutil
 from .base import ColumnCollection
 from .base import Executable
 from .elements import _clone
-from .elements import _literal_as_binds
 from .elements import _type_from_args
 from .elements import BinaryExpression
 from .elements import BindParameter
@@ -83,7 +84,12 @@ class FunctionElement(Executable, ColumnElement, FromClause):
         """Construct a :class:`.FunctionElement`.
         """
         args = [
-            _literal_as_binds(c, getattr(self, "name", None)) for c in clauses
+            coercions.expect(
+                roles.ExpressionElementRole,
+                c,
+                name=getattr(self, "name", None),
+            )
+            for c in clauses
         ]
         self._has_args = self._has_args or bool(args)
         self.clause_expr = ClauseList(
@@ -686,7 +692,12 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
     def __init__(self, *args, **kwargs):
         parsed_args = kwargs.pop("_parsed_args", None)
         if parsed_args is None:
-            parsed_args = [_literal_as_binds(c, self.name) for c in args]
+            parsed_args = [
+                coercions.expect(
+                    roles.ExpressionElementRole, c, name=self.name
+                )
+                for c in args
+            ]
         self._has_args = self._has_args or bool(parsed_args)
         self.packagenames = []
         self._bind = kwargs.get("bind", None)
@@ -751,7 +762,10 @@ class ReturnTypeFromArgs(GenericFunction):
     """Define a function whose return type is the same as its arguments."""
 
     def __init__(self, *args, **kwargs):
-        args = [_literal_as_binds(c, self.name) for c in args]
+        args = [
+            coercions.expect(roles.ExpressionElementRole, c, name=self.name)
+            for c in args
+        ]
         kwargs.setdefault("type_", _type_from_args(args))
         kwargs["_parsed_args"] = args
         super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
@@ -880,7 +894,7 @@ class array_agg(GenericFunction):
     type = sqltypes.ARRAY
 
     def __init__(self, *args, **kwargs):
-        args = [_literal_as_binds(c) for c in args]
+        args = [coercions.expect(roles.ExpressionElementRole, c) for c in args]
 
         default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
         if "type_" not in kwargs:
index 8479c1d5943f27f77d9caec3f61f186f13823c9d..b8bbb45252184e1380309ba2c81933aeb7096b61 100644 (file)
@@ -1053,7 +1053,7 @@ class ColumnOperators(Operators):
             expr = 5 == mytable.c.somearray.any_()
 
             # mysql '5 = ANY (SELECT value FROM table)'
-            expr = 5 == select([table.c.value]).as_scalar().any_()
+            expr = 5 == select([table.c.value]).scalar_subquery().any_()
 
         .. seealso::
 
@@ -1078,7 +1078,7 @@ class ColumnOperators(Operators):
             expr = 5 == mytable.c.somearray.all_()
 
             # mysql '5 = ALL (SELECT value FROM table)'
-            expr = 5 == select([table.c.value]).as_scalar().all_()
+            expr = 5 == select([table.c.value]).scalar_subquery().all_()
 
         .. seealso::
 
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
new file mode 100644 (file)
index 0000000..2d3aaf9
--- /dev/null
@@ -0,0 +1,157 @@
+# sql/roles.py
+# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+class SQLRole(object):
+    """Define a "role" within a SQL statement structure.
+
+    Classes within SQL Core participate within SQLRole hierarchies in order
+    to more accurately indicate where they may be used within SQL statements
+    of all types.
+
+    .. versionadded:: 1.4
+
+    """
+
+
+class UsesInspection(object):
+    pass
+
+
+class ColumnArgumentRole(SQLRole):
+    _role_name = "Column expression"
+
+
+class ColumnArgumentOrKeyRole(ColumnArgumentRole):
+    _role_name = "Column expression or string key"
+
+
+class ColumnListRole(SQLRole):
+    """Elements suitable for forming comma separated lists of expressions."""
+
+
+class TruncatedLabelRole(SQLRole):
+    _role_name = "String SQL identifier"
+
+
+class ColumnsClauseRole(UsesInspection, ColumnListRole):
+    _role_name = "Column expression or FROM clause"
+
+    @property
+    def _select_iterable(self):
+        raise NotImplementedError()
+
+
+class LimitOffsetRole(SQLRole):
+    _role_name = "LIMIT / OFFSET expression"
+
+
+class ByOfRole(ColumnListRole):
+    _role_name = "GROUP BY / OF / etc. expression"
+
+
+class OrderByRole(ByOfRole):
+    _role_name = "ORDER BY expression"
+
+
+class StructuralRole(SQLRole):
+    pass
+
+
+class StatementOptionRole(StructuralRole):
+    _role_name = "statement sub-expression element"
+
+
+class WhereHavingRole(StructuralRole):
+    _role_name = "SQL expression for WHERE/HAVING role"
+
+
+class ExpressionElementRole(SQLRole):
+    _role_name = "SQL expression element"
+
+
+class ConstExprRole(ExpressionElementRole):
+    _role_name = "Constant True/False/None expression"
+
+
+class LabeledColumnExprRole(ExpressionElementRole):
+    pass
+
+
+class BinaryElementRole(ExpressionElementRole):
+    _role_name = "SQL expression element or literal value"
+
+
+class InElementRole(SQLRole):
+    _role_name = (
+        "IN expression list, SELECT construct, or bound parameter object"
+    )
+
+
+class FromClauseRole(ColumnsClauseRole):
+    _role_name = "FROM expression, such as a Table or alias() object"
+
+    @property
+    def _hide_froms(self):
+        raise NotImplementedError()
+
+
+class CoerceTextStatementRole(SQLRole):
+    _role_name = "Executable SQL, text() construct, or string statement"
+
+
+class StatementRole(CoerceTextStatementRole):
+    _role_name = "Executable SQL or text() construct"
+
+
+class ReturnsRowsRole(StatementRole):
+    _role_name = (
+        "Row returning expression such as a SELECT, or an "
+        "INSERT/UPDATE/DELETE with RETURNING"
+    )
+
+
+class SelectStatementRole(ReturnsRowsRole):
+    _role_name = "SELECT construct or equivalent text() construct"
+
+
+class HasCTERole(ReturnsRowsRole):
+    pass
+
+
+class CompoundElementRole(SQLRole):
+    """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc."""
+
+    _role_name = (
+        "SELECT construct for inclusion in a UNION or other set construct"
+    )
+
+
+class DMLRole(StatementRole):
+    pass
+
+
+class DMLColumnRole(SQLRole):
+    _role_name = "SET/VALUES column expression or string key"
+
+
+class DMLSelectRole(SQLRole):
+    """A SELECT statement embedded in DML, typically INSERT from SELECT """
+
+    _role_name = "SELECT statement or equivalent textual object"
+
+
+class DDLRole(StatementRole):
+    pass
+
+
+class DDLExpressionRole(StructuralRole):
+    _role_name = "SQL expression element for DDL constraint"
+
+
+class DDLConstraintColumnRole(SQLRole):
+    _role_name = "String column name or column object for DDL constraint"
index b045e006e7d9e491bb2287247706906b7be75a71..62ff25a6461be192d483babc32d969f8c75cf1aa 100644 (file)
@@ -34,16 +34,16 @@ import collections
 import operator
 
 import sqlalchemy
+from . import coercions
 from . import ddl
+from . import roles
 from . import type_api
 from . import visitors
 from .base import _bind_or_error
 from .base import ColumnCollection
 from .base import DialectKWArgs
 from .base import SchemaEventTarget
-from .elements import _as_truncated
-from .elements import _document_text_coercion
-from .elements import _literal_as_text
+from .coercions import _document_text_coercion
 from .elements import ClauseElement
 from .elements import ColumnClause
 from .elements import ColumnElement
@@ -1583,7 +1583,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
             )
         try:
             c = self._constructor(
-                _as_truncated(name or self.name)
+                coercions.expect(
+                    roles.TruncatedLabelRole, name if name else self.name
+                )
                 if name_is_truncatable
                 else (name or self.name),
                 self.type,
@@ -2109,13 +2111,19 @@ class ForeignKey(DialectKWArgs, SchemaItem):
 
 
 class _NotAColumnExpr(object):
+    # the coercions system is not used in crud.py for the values passed in
+    # the insert().values() and update().values() methods, so the usual
+    # pathways to rejecting a coercion in the unlikely case of adding defaut
+    # generator objects to insert() or update() constructs aren't available;
+    # create a quick coercion rejection here that is specific to what crud.py
+    # calls on value objects.
     def _not_a_column_expr(self):
         raise exc.InvalidRequestError(
             "This %s cannot be used directly "
             "as a column expression." % self.__class__.__name__
         )
 
-    __clause_element__ = self_group = lambda self: self._not_a_column_expr()
+    self_group = lambda self: self._not_a_column_expr()  # noqa
     _from_objects = property(lambda self: self._not_a_column_expr())
 
 
@@ -2274,7 +2282,7 @@ class ColumnDefault(DefaultGenerator):
         return "ColumnDefault(%r)" % (self.arg,)
 
 
-class Sequence(DefaultGenerator):
+class Sequence(roles.StatementRole, DefaultGenerator):
     """Represents a named database sequence.
 
     The :class:`.Sequence` object represents the name and configurational
@@ -2759,25 +2767,6 @@ class ColumnCollectionMixin(object):
         if _autoattach and self._pending_colargs:
             self._check_attach()
 
-    @classmethod
-    def _extract_col_expression_collection(cls, expressions):
-        for expr in expressions:
-            strname = None
-            column = None
-            if hasattr(expr, "__clause_element__"):
-                expr = expr.__clause_element__()
-
-            if not isinstance(expr, (ColumnElement, TextClause)):
-                # this assumes a string
-                strname = expr
-            else:
-                cols = []
-                visitors.traverse(expr, {}, {"column": cols.append})
-                if cols:
-                    column = cols[0]
-            add_element = column if column is not None else strname
-            yield expr, column, strname, add_element
-
     def _check_attach(self, evt=False):
         col_objs = [c for c in self._pending_colargs if isinstance(c, Column)]
 
@@ -2960,7 +2949,7 @@ class CheckConstraint(ColumnCollectionConstraint):
 
         """
 
-        self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True)
+        self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
 
         columns = []
         visitors.traverse(self.sqltext, {}, {"column": columns.append})
@@ -3630,7 +3619,9 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
             column,
             strname,
             add_element,
-        ) in self._extract_col_expression_collection(expressions):
+        ) in coercions.expect_col_expression_collection(
+            roles.DDLConstraintColumnRole, expressions
+        ):
             if add_element is not None:
                 columns.append(add_element)
             processed_expressions.append(expr)
index 5167182fe2dbcd83ada670ae62312bdae624a134..41be9fc5a3a3a7fc76ffb14cff98ea0be72e0aae 100644 (file)
@@ -15,8 +15,9 @@ import itertools
 import operator
 from operator import attrgetter
 
-from sqlalchemy.sql.visitors import Visitable
+from . import coercions
 from . import operators
+from . import roles
 from . import type_api
 from .annotation import Annotated
 from .base import _from_objects
@@ -26,18 +27,12 @@ from .base import ColumnSet
 from .base import Executable
 from .base import Generative
 from .base import Immutable
+from .coercions import _document_text_coercion
 from .elements import _anonymous_label
-from .elements import _clause_element_as_expr
 from .elements import _clone
 from .elements import _cloned_difference
 from .elements import _cloned_intersection
-from .elements import _document_text_coercion
 from .elements import _expand_cloned
-from .elements import _interpret_as_column_or_from
-from .elements import _literal_and_labels_as_label_reference
-from .elements import _literal_as_label_reference
-from .elements import _literal_as_text
-from .elements import _no_text_coercion
 from .elements import _select_iterables
 from .elements import and_
 from .elements import BindParameter
@@ -48,75 +43,15 @@ from .elements import literal_column
 from .elements import True_
 from .elements import UnaryExpression
 from .. import exc
-from .. import inspection
 from .. import util
 
 
-def _interpret_as_from(element):
-    insp = inspection.inspect(element, raiseerr=False)
-    if insp is None:
-        if isinstance(element, util.string_types):
-            _no_text_coercion(element)
-    try:
-        return insp.selectable
-    except AttributeError:
-        raise exc.ArgumentError("FROM expression expected")
-
-
-def _interpret_as_select(element):
-    element = _interpret_as_from(element)
-    if isinstance(element, Alias):
-        element = element.original
-    if not isinstance(element, SelectBase):
-        element = element.select()
-    return element
-
-
 class _OffsetLimitParam(BindParameter):
     @property
     def _limit_offset_value(self):
         return self.effective_value
 
 
-def _offset_or_limit_clause(element, name=None, type_=None):
-    """Convert the given value to an "offset or limit" clause.
-
-    This handles incoming integers and converts to an expression; if
-    an expression is already given, it is passed through.
-
-    """
-    if element is None:
-        return None
-    elif hasattr(element, "__clause_element__"):
-        return element.__clause_element__()
-    elif isinstance(element, Visitable):
-        return element
-    else:
-        value = util.asint(element)
-        return _OffsetLimitParam(name, value, type_=type_, unique=True)
-
-
-def _offset_or_limit_clause_asint(clause, attrname):
-    """Convert the "offset or limit" clause of a select construct to an
-    integer.
-
-    This is only possible if the value is stored as a simple bound parameter.
-    Otherwise, a compilation error is raised.
-
-    """
-    if clause is None:
-        return None
-    try:
-        value = clause._limit_offset_value
-    except AttributeError:
-        raise exc.CompileError(
-            "This SELECT structure does not use a simple "
-            "integer value for %s" % attrname
-        )
-    else:
-        return util.asint(value)
-
-
 def subquery(alias, *args, **kwargs):
     r"""Return an :class:`.Alias` object derived
     from a :class:`.Select`.
@@ -133,8 +68,42 @@ def subquery(alias, *args, **kwargs):
     return Select(*args, **kwargs).alias(alias)
 
 
-class Selectable(ClauseElement):
-    """mark a class as being selectable"""
+class ReturnsRows(roles.ReturnsRowsRole, ClauseElement):
+    """The basemost class for Core contructs that have some concept of
+    columns that can represent rows.
+
+    While the SELECT statement and TABLE are the primary things we think
+    of in this category,  DML like INSERT, UPDATE and DELETE can also specify
+    RETURNING which means they can be used in CTEs and other forms, and
+    PostgreSQL has functions that return rows also.
+
+    .. versionadded:: 1.4
+
+    """
+
+    _is_returns_rows = True
+
+    # sub-elements of returns_rows
+    _is_from_clause = False
+    _is_select_statement = False
+    _is_lateral = False
+
+    @property
+    def selectable(self):
+        raise NotImplementedError(
+            "This object is a base ReturnsRows object, but is not a "
+            "FromClause so has no .c. collection."
+        )
+
+
+class Selectable(ReturnsRows):
+    """mark a class as being selectable.
+
+    This class is legacy as of 1.4 as the concept of a SQL construct which
+    "returns rows" is more generalized than one which can be the subject
+    of a SELECT.
+
+    """
 
     __visit_name__ = "selectable"
 
@@ -190,7 +159,7 @@ class HasPrefixes(object):
     def _setup_prefixes(self, prefixes, dialect=None):
         self._prefixes = self._prefixes + tuple(
             [
-                (_literal_as_text(p, allow_coercion_to_text=True), dialect)
+                (coercions.expect(roles.StatementOptionRole, p), dialect)
                 for p in prefixes
             ]
         )
@@ -236,13 +205,13 @@ class HasSuffixes(object):
     def _setup_suffixes(self, suffixes, dialect=None):
         self._suffixes = self._suffixes + tuple(
             [
-                (_literal_as_text(p, allow_coercion_to_text=True), dialect)
+                (coercions.expect(roles.StatementOptionRole, p), dialect)
                 for p in suffixes
             ]
         )
 
 
-class FromClause(Selectable):
+class FromClause(roles.FromClauseRole, Selectable):
     """Represent an element that can be used within the ``FROM``
     clause of a ``SELECT`` statement.
 
@@ -265,16 +234,6 @@ class FromClause(Selectable):
     named_with_column = False
     _hide_froms = []
 
-    _is_join = False
-    _is_select = False
-    _is_from_container = False
-
-    _is_lateral = False
-
-    _textual = False
-    """a marker that allows us to easily distinguish a :class:`.TextAsFrom`
-    or similar object from other kinds of :class:`.FromClause` objects."""
-
     schema = None
     """Define the 'schema' attribute for this :class:`.FromClause`.
 
@@ -284,6 +243,11 @@ class FromClause(Selectable):
 
     """
 
+    is_selectable = has_selectable = True
+    _is_from_clause = True
+    _is_text_as_from = False
+    _is_join = False
+
     def _translate_schema(self, effective_schema, map_):
         return effective_schema
 
@@ -726,8 +690,8 @@ class Join(FromClause):
         :class:`.FromClause` object.
 
         """
-        self.left = _interpret_as_from(left)
-        self.right = _interpret_as_from(right).self_group()
+        self.left = coercions.expect(roles.FromClauseRole, left)
+        self.right = coercions.expect(roles.FromClauseRole, right).self_group()
 
         if onclause is None:
             self.onclause = self._match_primaries(self.left, self.right)
@@ -1292,7 +1256,9 @@ class Alias(FromClause):
          .. versionadded:: 0.9.0
 
         """
-        return _interpret_as_from(selectable).alias(name=name, flat=flat)
+        return coercions.expect(roles.FromClauseRole, selectable).alias(
+            name=name, flat=flat
+        )
 
     def _init(self, selectable, name=None):
         baseselectable = selectable
@@ -1327,14 +1293,6 @@ class Alias(FromClause):
         else:
             return self.name.encode("ascii", "backslashreplace")
 
-    def as_scalar(self):
-        try:
-            return self.element.as_scalar()
-        except AttributeError:
-            raise AttributeError(
-                "Element %s does not support " "'as_scalar()'" % self.element
-            )
-
     def is_derived_from(self, fromclause):
         if fromclause in self._cloned_set:
             return True
@@ -1426,7 +1384,9 @@ class Lateral(Alias):
             :ref:`lateral_selects` -  overview of usage.
 
         """
-        return _interpret_as_from(selectable).lateral(name=name)
+        return coercions.expect(roles.FromClauseRole, selectable).lateral(
+            name=name
+        )
 
 
 class TableSample(Alias):
@@ -1488,7 +1448,7 @@ class TableSample(Alias):
          REPEATABLE sub-clause is also rendered.
 
         """
-        return _interpret_as_from(selectable).tablesample(
+        return coercions.expect(roles.FromClauseRole, selectable).tablesample(
             sampling, name=name, seed=seed
         )
 
@@ -1523,7 +1483,7 @@ class CTE(Generative, HasSuffixes, Alias):
         Please see :meth:`.HasCte.cte` for detail on CTE usage.
 
         """
-        return _interpret_as_from(selectable).cte(
+        return coercions.expect(roles.HasCTERole, selectable).cte(
             name=name, recursive=recursive
         )
 
@@ -1588,7 +1548,7 @@ class CTE(Generative, HasSuffixes, Alias):
         )
 
 
-class HasCTE(object):
+class HasCTE(roles.HasCTERole):
     """Mixin that declares a class to include CTE support.
 
     .. versionadded:: 1.1
@@ -2059,13 +2019,22 @@ class ForUpdateArg(ClauseElement):
         self.key_share = key_share
         if of is not None:
             self.of = [
-                _interpret_as_column_or_from(elem) for elem in util.to_list(of)
+                coercions.expect(roles.ColumnsClauseRole, elem)
+                for elem in util.to_list(of)
             ]
         else:
             self.of = None
 
 
-class SelectBase(HasCTE, Executable, FromClause):
+class SelectBase(
+    roles.SelectStatementRole,
+    roles.DMLSelectRole,
+    roles.CompoundElementRole,
+    roles.InElementRole,
+    HasCTE,
+    Executable,
+    FromClause,
+):
     """Base class for SELECT statements.
 
 
@@ -2075,15 +2044,32 @@ class SelectBase(HasCTE, Executable, FromClause):
 
     """
 
+    _is_select_statement = True
+
+    @util.deprecated(
+        "1.4",
+        "The :meth:`.SelectBase.as_scalar` method is deprecated and will be "
+        "removed in a future release.  Please refer to "
+        ":meth:`.SelectBase.scalar_subquery`.",
+    )
     def as_scalar(self):
+        return self.scalar_subquery()
+
+    def scalar_subquery(self):
         """return a 'scalar' representation of this selectable, which can be
         used as a column expression.
 
         Typically, a select statement which has only one column in its columns
-        clause is eligible to be used as a scalar expression.
+        clause is eligible to be used as a scalar expression.  The scalar
+        subquery can then be used in the WHERE clause or columns clause of
+        an enclosing SELECT.
 
-        The returned object is an instance of
-        :class:`ScalarSelect`.
+        Note that the scalar subquery differentiates from the FROM-level
+        subquery that can be produced using the :meth:`.SelectBase.subquery`
+        method.
+
+        .. versionchanged: 1.4 - the ``.as_scalar()`` method was renamed to
+           :meth:`.SelectBase.scalar_subquery`.
 
         """
         return ScalarSelect(self)
@@ -2097,7 +2083,7 @@ class SelectBase(HasCTE, Executable, FromClause):
             :meth:`~.SelectBase.as_scalar`.
 
         """
-        return self.as_scalar().label(name)
+        return self.scalar_subquery().label(name)
 
     @_generative
     @util.deprecated(
@@ -2181,20 +2167,19 @@ class GenerativeSelect(SelectBase):
                 {"autocommit": autocommit}
             )
         if limit is not None:
-            self._limit_clause = _offset_or_limit_clause(limit)
+            self._limit_clause = self._offset_or_limit_clause(limit)
         if offset is not None:
-            self._offset_clause = _offset_or_limit_clause(offset)
+            self._offset_clause = self._offset_or_limit_clause(offset)
         self._bind = bind
 
         if order_by is not None:
             self._order_by_clause = ClauseList(
                 *util.to_list(order_by),
-                _literal_as_text=_literal_and_labels_as_label_reference
+                _literal_as_text_role=roles.OrderByRole
             )
         if group_by is not None:
             self._group_by_clause = ClauseList(
-                *util.to_list(group_by),
-                _literal_as_text=_literal_as_label_reference
+                *util.to_list(group_by), _literal_as_text_role=roles.ByOfRole
             )
 
     @property
@@ -2287,6 +2272,37 @@ class GenerativeSelect(SelectBase):
         """
         self.use_labels = True
 
+    def _offset_or_limit_clause(self, element, name=None, type_=None):
+        """Convert the given value to an "offset or limit" clause.
+
+        This handles incoming integers and converts to an expression; if
+        an expression is already given, it is passed through.
+
+        """
+        return coercions.expect(
+            roles.LimitOffsetRole, element, name=name, type_=type_
+        )
+
+    def _offset_or_limit_clause_asint(self, clause, attrname):
+        """Convert the "offset or limit" clause of a select construct to an
+        integer.
+
+        This is only possible if the value is stored as a simple bound
+        parameter. Otherwise, a compilation error is raised.
+
+        """
+        if clause is None:
+            return None
+        try:
+            value = clause._limit_offset_value
+        except AttributeError:
+            raise exc.CompileError(
+                "This SELECT structure does not use a simple "
+                "integer value for %s" % attrname
+            )
+        else:
+            return util.asint(value)
+
     @property
     def _limit(self):
         """Get an integer value for the limit.  This should only be used
@@ -2295,7 +2311,7 @@ class GenerativeSelect(SelectBase):
         isn't currently set to an integer.
 
         """
-        return _offset_or_limit_clause_asint(self._limit_clause, "limit")
+        return self._offset_or_limit_clause_asint(self._limit_clause, "limit")
 
     @property
     def _simple_int_limit(self):
@@ -2319,7 +2335,9 @@ class GenerativeSelect(SelectBase):
         offset isn't currently set to an integer.
 
         """
-        return _offset_or_limit_clause_asint(self._offset_clause, "offset")
+        return self._offset_or_limit_clause_asint(
+            self._offset_clause, "offset"
+        )
 
     @_generative
     def limit(self, limit):
@@ -2339,7 +2357,7 @@ class GenerativeSelect(SelectBase):
 
         """
 
-        self._limit_clause = _offset_or_limit_clause(limit)
+        self._limit_clause = self._offset_or_limit_clause(limit)
 
     @_generative
     def offset(self, offset):
@@ -2361,7 +2379,7 @@ class GenerativeSelect(SelectBase):
 
         """
 
-        self._offset_clause = _offset_or_limit_clause(offset)
+        self._offset_clause = self._offset_or_limit_clause(offset)
 
     @_generative
     def order_by(self, *clauses):
@@ -2403,8 +2421,7 @@ class GenerativeSelect(SelectBase):
             if getattr(self, "_order_by_clause", None) is not None:
                 clauses = list(self._order_by_clause) + list(clauses)
             self._order_by_clause = ClauseList(
-                *clauses,
-                _literal_as_text=_literal_and_labels_as_label_reference
+                *clauses, _literal_as_text_role=roles.OrderByRole
             )
 
     def append_group_by(self, *clauses):
@@ -2423,7 +2440,7 @@ class GenerativeSelect(SelectBase):
             if getattr(self, "_group_by_clause", None) is not None:
                 clauses = list(self._group_by_clause) + list(clauses)
             self._group_by_clause = ClauseList(
-                *clauses, _literal_as_text=_literal_as_label_reference
+                *clauses, _literal_as_text_role=roles.ByOfRole
             )
 
     @property
@@ -2478,7 +2495,7 @@ class CompoundSelect(GenerativeSelect):
 
         # some DBs do not like ORDER BY in the inner queries of a UNION, etc.
         for n, s in enumerate(selects):
-            s = _clause_element_as_expr(s)
+            s = coercions.expect(roles.CompoundElementRole, s)
 
             if not numcols:
                 numcols = len(s.c._all_columns)
@@ -2741,7 +2758,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
     _correlate = ()
     _correlate_except = None
     _memoized_property = SelectBase._memoized_property
-    _is_select = True
 
     @util.deprecated_params(
         autocommit=(
@@ -2965,12 +2981,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
                 self._distinct = True
             else:
                 self._distinct = [
-                    _literal_as_text(e) for e in util.to_list(distinct)
+                    coercions.expect(roles.WhereHavingRole, e)
+                    for e in util.to_list(distinct)
                 ]
 
         if from_obj is not None:
             self._from_obj = util.OrderedSet(
-                _interpret_as_from(f) for f in util.to_list(from_obj)
+                coercions.expect(roles.FromClauseRole, f)
+                for f in util.to_list(from_obj)
             )
         else:
             self._from_obj = util.OrderedSet()
@@ -2986,7 +3004,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
         if cols_present:
             self._raw_columns = []
             for c in columns:
-                c = _interpret_as_column_or_from(c)
+                c = coercions.expect(roles.ColumnsClauseRole, c)
                 if isinstance(c, ScalarSelect):
                     c = c.self_group(against=operators.comma_op)
                 self._raw_columns.append(c)
@@ -2994,16 +3012,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
             self._raw_columns = []
 
         if whereclause is not None:
-            self._whereclause = _literal_as_text(whereclause).self_group(
-                against=operators._asbool
-            )
+            self._whereclause = coercions.expect(
+                roles.WhereHavingRole, whereclause
+            ).self_group(against=operators._asbool)
         else:
             self._whereclause = None
 
         if having is not None:
-            self._having = _literal_as_text(having).self_group(
-                against=operators._asbool
-            )
+            self._having = coercions.expect(
+                roles.WhereHavingRole, having
+            ).self_group(against=operators._asbool)
         else:
             self._having = None
 
@@ -3202,15 +3220,6 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
         else:
             self._hints = self._hints.union({(selectable, dialect_name): text})
 
-    @property
-    def type(self):
-        raise exc.InvalidRequestError(
-            "Select objects don't have a type.  "
-            "Call as_scalar() on this Select "
-            "object to return a 'scalar' version "
-            "of this Select."
-        )
-
     @_memoized_property.method
     def locate_all_froms(self):
         """return a Set of all FromClause elements referenced by this Select.
@@ -3496,7 +3505,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
         self._reset_exported()
         rc = []
         for c in columns:
-            c = _interpret_as_column_or_from(c)
+            c = coercions.expect(roles.ColumnsClauseRole, c)
             if isinstance(c, ScalarSelect):
                 c = c.self_group(against=operators.comma_op)
             rc.append(c)
@@ -3530,7 +3539,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
 
         """
         if expr:
-            expr = [_literal_as_label_reference(e) for e in expr]
+            expr = [coercions.expect(roles.ByOfRole, e) for e in expr]
             if isinstance(self._distinct, list):
                 self._distinct = self._distinct + expr
             else:
@@ -3618,7 +3627,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
             self._correlate = ()
         else:
             self._correlate = set(self._correlate).union(
-                _interpret_as_from(f) for f in fromclauses
+                coercions.expect(roles.FromClauseRole, f) for f in fromclauses
             )
 
     @_generative
@@ -3653,7 +3662,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
             self._correlate_except = ()
         else:
             self._correlate_except = set(self._correlate_except or ()).union(
-                _interpret_as_from(f) for f in fromclauses
+                coercions.expect(roles.FromClauseRole, f) for f in fromclauses
             )
 
     def append_correlation(self, fromclause):
@@ -3668,7 +3677,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
 
         self._auto_correlate = False
         self._correlate = set(self._correlate).union(
-            _interpret_as_from(f) for f in fromclause
+            coercions.expect(roles.FromClauseRole, f) for f in fromclause
         )
 
     def append_column(self, column):
@@ -3689,7 +3698,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
 
         """
         self._reset_exported()
-        column = _interpret_as_column_or_from(column)
+        column = coercions.expect(roles.ColumnsClauseRole, column)
 
         if isinstance(column, ScalarSelect):
             column = column.self_group(against=operators.comma_op)
@@ -3705,7 +3714,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
         standard :term:`method chaining`.
 
         """
-        clause = _literal_as_text(clause)
+        clause = coercions.expect(roles.WhereHavingRole, clause)
         self._prefixes = self._prefixes + (clause,)
 
     def append_whereclause(self, whereclause):
@@ -3747,7 +3756,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
 
         """
         self._reset_exported()
-        fromclause = _interpret_as_from(fromclause)
+        fromclause = coercions.expect(roles.FromClauseRole, fromclause)
         self._from_obj = self._from_obj.union([fromclause])
 
     @_memoized_property
@@ -3894,7 +3903,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
     bind = property(bind, _set_bind)
 
 
-class ScalarSelect(Generative, Grouping):
+class ScalarSelect(roles.InElementRole, Generative, Grouping):
     _from_objects = []
     _is_from_container = True
     _is_implicitly_boolean = False
@@ -3956,7 +3965,7 @@ class Exists(UnaryExpression):
         else:
             if not args:
                 args = ([literal_column("*")],)
-            s = Select(*args, **kwargs).as_scalar().self_group()
+            s = Select(*args, **kwargs).scalar_subquery().self_group()
 
         UnaryExpression.__init__(
             self,
@@ -3999,6 +4008,7 @@ class Exists(UnaryExpression):
         return e
 
 
+# TODO: rename to TextualSelect, this is not a FROM clause
 class TextAsFrom(SelectBase):
     """Wrap a :class:`.TextClause` construct within a :class:`.SelectBase`
     interface.
@@ -4022,7 +4032,7 @@ class TextAsFrom(SelectBase):
 
     __visit_name__ = "text_as_from"
 
-    _textual = True
+    _is_textual = True
 
     def __init__(self, text, columns, positional=False):
         self.element = text
index 0d39445527427dfc7e12d5654023ca1c368b1514..6a520a2d59d22799d0e710bee94c57e96cfa2ea7 100644 (file)
@@ -14,13 +14,14 @@ import datetime as dt
 import decimal
 import json
 
+from . import coercions
 from . import elements
 from . import operators
+from . import roles
 from . import type_api
 from .base import _bind_or_error
 from .base import SchemaEventTarget
 from .elements import _defer_name
-from .elements import _literal_as_binds
 from .elements import quoted_name
 from .elements import Slice
 from .elements import TypeCoerce as type_coerce  # noqa
@@ -2187,19 +2188,21 @@ class JSON(Indexable, TypeEngine):
             if not isinstance(index, util.string_types) and isinstance(
                 index, compat.collections_abc.Sequence
             ):
-                index = default_comparator._check_literal(
-                    self.expr,
-                    operators.json_path_getitem_op,
+                index = coercions.expect(
+                    roles.BinaryElementRole,
                     index,
+                    expr=self.expr,
+                    operator=operators.json_path_getitem_op,
                     bindparam_type=JSON.JSONPathType,
                 )
 
                 operator = operators.json_path_getitem_op
             else:
-                index = default_comparator._check_literal(
-                    self.expr,
-                    operators.json_getitem_op,
+                index = coercions.expect(
+                    roles.BinaryElementRole,
                     index,
+                    expr=self.expr,
+                    operator=operators.json_getitem_op,
                     bindparam_type=JSON.JSONIndexType,
                 )
                 operator = operators.json_getitem_op
@@ -2372,17 +2375,20 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
                 if self.type.zero_indexes:
                     index = slice(index.start + 1, index.stop + 1, index.step)
                 index = Slice(
-                    _literal_as_binds(
+                    coercions.expect(
+                        roles.ExpressionElementRole,
                         index.start,
                         name=self.expr.key,
                         type_=type_api.INTEGERTYPE,
                     ),
-                    _literal_as_binds(
+                    coercions.expect(
+                        roles.ExpressionElementRole,
                         index.stop,
                         name=self.expr.key,
                         type_=type_api.INTEGERTYPE,
                     ),
-                    _literal_as_binds(
+                    coercions.expect(
+                        roles.ExpressionElementRole,
                         index.step,
                         name=self.expr.key,
                         type_=type_api.INTEGERTYPE,
@@ -2438,7 +2444,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
             """
             operator = operator if operator else operators.eq
             return operator(
-                elements._literal_as_binds(other),
+                coercions.expect(roles.ExpressionElementRole, other),
                 elements.CollectionAggregate._create_any(self.expr),
             )
 
@@ -2473,7 +2479,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
             """
             operator = operator if operator else operators.eq
             return operator(
-                elements._literal_as_binds(other),
+                coercions.expect(roles.ExpressionElementRole, other),
                 elements.CollectionAggregate._create_all(self.expr),
             )
 
index bdeae961376033c7d126e7bbf193c01ffc1fdabc..5eea27e08a14889fd2828d7471f6e73ed54b9351 100644 (file)
@@ -57,6 +57,9 @@ class TypeEngine(Visitable):
 
         default_comparator = None
 
+        def __clause_element__(self):
+            return self.expr
+
         def __init__(self, expr):
             self.expr = expr
             self.type = expr.type
index c52e9a76f1a1a246ed424f6fe9acb5db9ef21672..090c8488a5910554948ff5c9285e814861bbc2ef 100644 (file)
@@ -23,6 +23,7 @@ from .assertions import expect_warnings  # noqa
 from .assertions import in_  # noqa
 from .assertions import is_  # noqa
 from .assertions import is_false  # noqa
+from .assertions import is_instance_of  # noqa
 from .assertions import is_not_  # noqa
 from .assertions import is_true  # noqa
 from .assertions import le_  # noqa
index d8038e225cf5cd139400dd3ef3aca6fb500fada9..819fedcc772b016734085ecbcb5f3abe5f98f8fa 100644 (file)
@@ -247,6 +247,10 @@ def le_(a, b, msg=None):
     assert a <= b, msg or "%r != %r" % (a, b)
 
 
+def is_instance_of(a, b, msg=None):
+    assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
+
+
 def is_true(a, msg=None):
     is_(a, True, msg=msg)
 
index 012de7911d9e08266a46fb09b9ebc0733fe0abdc..c7e6a266ca4384b26b1b2db4a5624f79750d7c65 100644 (file)
@@ -198,9 +198,9 @@ class CTETest(fixtures.TablesTest):
             conn.execute(
                 some_other_table.delete().where(
                     some_other_table.c.data
-                    == select([cte.c.data]).where(
-                        cte.c.id == some_other_table.c.id
-                    )
+                    == select([cte.c.data])
+                    .where(cte.c.id == some_other_table.c.id)
+                    .scalar_subquery()
                 )
             )
             eq_(
index a07e45df2d37bcd0178197094436801e6ec374c9..8bf66d2eb9936ad495f57f57ef3e594b8f0959ef 100644 (file)
@@ -98,7 +98,7 @@ class RowFetchTest(fixtures.TablesTest):
 
         """
         datetable = self.tables.has_dates
-        s = select([datetable.alias("x").c.today]).as_scalar()
+        s = select([datetable.alias("x").c.today]).scalar_subquery()
         s2 = select([datetable.c.id, s.label("somelabel")])
         row = config.db.execute(s2).first()
 
index e6e4907abbf46b9b150952e005421fe13c1752e2..a8cdc5ef7fc0399e8aa6096bfec21c3ab58fd4ea 100644 (file)
@@ -140,6 +140,7 @@ from .langhelpers import portable_instancemethod  # noqa
 from .langhelpers import quoted_token_parser  # noqa
 from .langhelpers import safe_reraise  # noqa
 from .langhelpers import set_creation_order  # noqa
+from .langhelpers import string_or_unprintable  # noqa
 from .langhelpers import symbol  # noqa
 from .langhelpers import unbound_method_to_callable  # noqa
 from .langhelpers import warn  # noqa
index 7d1321e0b83c342d4cea2fda27e18e3096990778..7a7faff60c4faa65106c291958170e3887efc729 100644 (file)
@@ -79,6 +79,16 @@ class safe_reraise(object):
             compat.reraise(type_, value, traceback)
 
 
+def string_or_unprintable(element):
+    if isinstance(element, compat.string_types):
+        return element
+    else:
+        try:
+            return str(element)
+        except Exception:
+            return "unprintable element %r" % element
+
+
 def clsname_as_plain_name(cls):
     return " ".join(
         n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__)
index 30a11d16b2f170962e29f4dc7a4d85f4f70b7245..498de763c6d9ed857963b15c090caf52eace1552 100644 (file)
@@ -267,7 +267,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         stmt = table.update().values(
             val=select([other.c.newval])
             .where(table.c.sym == other.c.sym)
-            .as_scalar()
+            .scalar_subquery()
         )
 
         self.assert_compile(
@@ -334,14 +334,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
         t = table("sometable", column("somecolumn"))
         self.assert_compile(
-            t.select().where(t.c.somecolumn == t.select()),
+            t.select().where(t.c.somecolumn == t.select().scalar_subquery()),
             "SELECT sometable.somecolumn FROM "
             "sometable WHERE sometable.somecolumn = "
             "(SELECT sometable.somecolumn FROM "
             "sometable)",
         )
         self.assert_compile(
-            t.select().where(t.c.somecolumn != t.select()),
+            t.select().where(t.c.somecolumn != t.select().scalar_subquery()),
             "SELECT sometable.somecolumn FROM "
             "sometable WHERE sometable.somecolumn != "
             "(SELECT sometable.somecolumn FROM "
@@ -844,7 +844,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         t1 = table("t1", column("x", Integer), column("y", Integer))
         t2 = table("t2", column("x", Integer), column("y", Integer))
 
-        order_by = select([t2.c.y]).where(t1.c.x == t2.c.x).as_scalar()
+        order_by = select([t2.c.y]).where(t1.c.x == t2.c.x).scalar_subquery()
         s = (
             select([t1])
             .where(t1.c.x == 5)
@@ -1135,7 +1135,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         idx = Index("test_idx_data_1", tbl.c.data, mssql_where=tbl.c.data > 1)
         self.assert_compile(
             schema.CreateIndex(idx),
-            "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1"
+            "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1",
         )
 
     def test_index_ordering(self):
index bf836fc14741d05029759bd912f15843195b662b..4ecf0634c962a626dcfabee55c2fac0fd8d23a1c 100644 (file)
@@ -130,7 +130,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def test_column_subquery_to_alias(self):
         a1 = self.t2.alias("a1")
-        s = select([self.t2, select([a1.c.a]).as_scalar()])
+        s = select([self.t2, select([a1.c.a]).scalar_subquery()])
         self._assert_sql(
             s,
             "SELECT t2_1.a, t2_1.b, t2_1.c, "
index e492f1e17f2e3da7809bdccfe78e9d964fe6430b..39485ae102d4367abdf585f1c1d9115aa4c26846 100644 (file)
@@ -247,7 +247,7 @@ class AnyAllTest(fixtures.TablesTest):
     def test_any_w_comparator(self):
         stuff = self.tables.stuff
         stmt = select([stuff.c.id]).where(
-            stuff.c.value > any_(select([stuff.c.value]))
+            stuff.c.value > any_(select([stuff.c.value]).scalar_subquery())
         )
 
         eq_(testing.db.execute(stmt).fetchall(), [(2,), (3,), (4,), (5,)])
@@ -255,13 +255,13 @@ class AnyAllTest(fixtures.TablesTest):
     def test_all_w_comparator(self):
         stuff = self.tables.stuff
         stmt = select([stuff.c.id]).where(
-            stuff.c.value >= all_(select([stuff.c.value]))
+            stuff.c.value >= all_(select([stuff.c.value]).scalar_subquery())
         )
 
         eq_(testing.db.execute(stmt).fetchall(), [(5,)])
 
     def test_any_literal(self):
         stuff = self.tables.stuff
-        stmt = select([4 == any_(select([stuff.c.value]))])
+        stmt = select([4 == any_(select([stuff.c.value]).scalar_subquery())])
 
         is_(testing.db.execute(stmt).scalar(), True)
index 3fe2f1bfe97e8f20ee25f0a2e398fc97cdc2ddb7..b6c911813b8b3f73f8831333d82daff6c4d991a6 100644 (file)
@@ -705,7 +705,7 @@ class DeclarativeTest(DeclarativeTestBase):
             rel = relationship("User", primaryjoin="User.id==Bar.__table__.id")
 
         assert_raises_message(
-            exc.InvalidRequestError,
+            AttributeError,
             "does not have a mapped column named " "'__table__'",
             configure_mappers,
         )
@@ -1469,7 +1469,7 @@ class DeclarativeTest(DeclarativeTestBase):
         User.address_count = sa.orm.column_property(
             sa.select([sa.func.count(Address.id)])
             .where(Address.user_id == User.id)
-            .as_scalar()
+            .scalar_subquery()
         )
         Base.metadata.create_all()
         u1 = User(
@@ -1514,9 +1514,9 @@ class DeclarativeTest(DeclarativeTestBase):
                 # this doesn't really gain us anything.  but if
                 # one is used, lets have it function as expected...
                 return sa.orm.column_property(
-                    sa.select([sa.func.count(Address.id)]).where(
-                        Address.user_id == cls.id
-                    )
+                    sa.select([sa.func.count(Address.id)])
+                    .where(Address.user_id == cls.id)
+                    .scalar_subquery()
                 )
 
         Base.metadata.create_all()
@@ -1616,7 +1616,7 @@ class DeclarativeTest(DeclarativeTestBase):
             adr_count = sa.orm.column_property(
                 sa.select(
                     [sa.func.count(Address.id)], Address.user_id == id
-                ).as_scalar()
+                ).scalar_subquery()
             )
             addresses = relationship(Address)
 
@@ -1920,7 +1920,7 @@ class DeclarativeTest(DeclarativeTestBase):
         User.address_count = sa.orm.column_property(
             sa.select([sa.func.count(Address.id)])
             .where(Address.user_id == User.id)
-            .as_scalar()
+            .scalar_subquery()
         )
         Base.metadata.create_all()
         u1 = User(
index ef9bbd354d3f69c1365b941ce8cfca1cdf63f253..df7dea77c482ed20b3e12fd7f98431bc200cf8be 100644 (file)
@@ -1851,7 +1851,7 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
                 return column_property(
                     select([func.count(Address.id)])
                     .where(Address.user_id == cls.id)
-                    .as_scalar()
+                    .scalar_subquery()
                 )
 
         class Address(Base):
index 55cd9376bf31756487736ee3bbfeccc9cc251880..00c6a78b4de90e031d903b92709df154ffe41360 100644 (file)
@@ -765,9 +765,11 @@ class ResultTest(BakedTest):
         )
 
         main_bq = self.bakery(
-            lambda s: s.query(Address.id, sub_bq.to_query(s).as_scalar())
+            lambda s: s.query(Address.id, sub_bq.to_query(s).scalar_subquery())
+        )
+        main_bq += lambda q: q.filter(
+            sub_bq.to_query(q).scalar_subquery() == "ed"
         )
-        main_bq += lambda q: q.filter(sub_bq.to_query(q).as_scalar() == "ed")
         main_bq += lambda q: q.order_by(Address.id)
 
         sess = Session()
index 525824669c13c05ada80e527f3624afef593225a..07ffd93857dafd45e331cad23dbcb8fb9a759e3d 100644 (file)
@@ -2184,6 +2184,7 @@ class CorrelateExceptWPolyAdaptTest(
                     select([func.count(Superclass.id)])
                     .where(Superclass.common_id == id)
                     .correlate_except(Superclass)
+                    .scalar_subquery()
                 )
 
         if not use_correlate_except:
@@ -2191,6 +2192,7 @@ class CorrelateExceptWPolyAdaptTest(
                 select([func.count(Superclass.id)])
                 .where(Superclass.common_id == Common.id)
                 .correlate(Common)
+                .scalar_subquery()
             )
 
         return Common, Superclass
index 59ecc7c986b22d2b547b4088f85324ab4ff5bf99..472dafcc4b71fec37640d739f9d2e0e3bfef9c7d 100644 (file)
@@ -344,8 +344,8 @@ class PolymorphicOnNotLocalTest(fixtures.MappedTest):
 
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Only direct column-mapped property or "
-            "SQL expression can be passed for polymorphic_on",
+            r"Column expression or string key expected for argument "
+            r"'polymorphic_on'; got .*function",
             go,
         )
 
index c16573b237d6d301ce59746935bcd66b38f8e957..bbf0d472a16f11578a9b7cf20cdc61bd2612234b 100644 (file)
@@ -1295,7 +1295,7 @@ class _PolymorphicTestBase(object):
         subq = (
             sess.query(engineers.c.person_id)
             .filter(Engineer.primary_language == "java")
-            .statement.as_scalar()
+            .statement.scalar_subquery()
         )
         eq_(sess.query(Person).filter(Person.person_id.in_(subq)).one(), e1)
 
@@ -1634,15 +1634,16 @@ class _PolymorphicTestBase(object):
 
         # this for a long time did not work with PolymorphicAliased and
         # PolymorphicUnions, which was due to the no_replacement_traverse
-        # annotation added to query.statement which then went into as_scalar().
-        # this is removed as of :ticket:`4304` so now works.
+        # annotation added to query.statement which then went into
+        # scalar_subquery(). this is removed as of :ticket:`4304` so now
+        # works.
         eq_(
             sess.query(Person.name)
             .filter(
                 sess.query(Company.name)
                 .filter(Company.company_id == Person.company_id)
                 .correlate(Person)
-                .as_scalar()
+                .scalar_subquery()
                 == "Elbonia, Inc."
             )
             .all(),
@@ -1660,7 +1661,7 @@ class _PolymorphicTestBase(object):
                 sess.query(Company.name)
                 .filter(Company.company_id == paliased.company_id)
                 .correlate(paliased)
-                .as_scalar()
+                .scalar_subquery()
                 == "Elbonia, Inc."
             )
             .all(),
@@ -1678,7 +1679,7 @@ class _PolymorphicTestBase(object):
                 sess.query(Company.name)
                 .filter(Company.company_id == paliased.company_id)
                 .correlate(paliased)
-                .as_scalar()
+                .scalar_subquery()
                 == "Elbonia, Inc."
             )
             .all(),
@@ -1720,7 +1721,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic):
                 sess.query(Company.name)
                 .filter(Company.company_id == p_poly.company_id)
                 .correlate(p_poly)
-                .as_scalar()
+                .scalar_subquery()
                 == "Elbonia, Inc."
             )
             .all(),
@@ -1739,7 +1740,7 @@ class PolymorphicTest(_PolymorphicTestBase, _Polymorphic):
                 sess.query(Company.name)
                 .filter(Company.company_id == p_poly.company_id)
                 .correlate(p_poly)
-                .as_scalar()
+                .scalar_subquery()
                 == "Elbonia, Inc."
             )
             .all(),
index 1b28974b7b5b5f802c44b42b0a455178707db592..a54cfbe9330580e1129b1c732448ab8088fc1847 100644 (file)
@@ -962,7 +962,7 @@ class RelationshipToSingleTest(
             .select_from(Engineer)
             .filter(Engineer.company_id == Company.company_id)
             .correlate(Company)
-            .as_scalar()
+            .scalar_subquery()
         )
 
         self.assert_compile(
index 83f4f44511fdbde710855f845b9c630303cfb9fd..eb0df3107f9301a0c6fa39e55de10feb1e437634 100644 (file)
@@ -1827,15 +1827,15 @@ class DictHelpersTest(fixtures.MappedTest):
     def test_column_mapped_assertions(self):
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Column-based expression object expected "
-            "for argument 'mapping_spec'; got: 'a'",
+            "Column expression expected "
+            "for argument 'mapping_spec'; got 'a'.",
             collections.column_mapped_collection,
             "a",
         )
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Column-based expression object expected "
-            "for argument 'mapping_spec'; got: 'a'",
+            "Column expression expected "
+            "for argument 'mapping_spec'; got .*TextClause.",
             collections.column_mapped_collection,
             text("a"),
         )
index 679b3ec5bd459ce6eeb0dcc2d4e917d3c1b7e0f1..c5938416cda835bd7c7badb166f56f79ce67bdb9 100644 (file)
@@ -550,6 +550,17 @@ class StrongIdentityMapTest(_fixtures.FixtureTest):
 class DeprecatedMapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     __dialect__ = "default"
 
+    def test_query_as_scalar(self):
+        users, User = self.tables.users, self.classes.User
+
+        mapper(User, users)
+        s = Session()
+        with assertions.expect_deprecated(
+            r"The Query.as_scalar\(\) method is deprecated and will "
+            "be removed in a future release."
+        ):
+            s.query(User).as_scalar()
+
     def test_cancel_order_by(self):
         users, User = self.tables.users, self.classes.User
 
index 4adf9a72f26852b92379ef4287aaf9a3d04ef579..499803543d03eedce718793f34795f52ade85ca5 100644 (file)
@@ -3258,7 +3258,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             b_table.c.a_id == a_table.c.id
         )
 
-        self._fixture({"summation": column_property(cp)})
+        self._fixture({"summation": column_property(cp.scalar_subquery())})
         self.assert_compile(
             create_session()
             .query(A)
@@ -3281,7 +3281,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             b_table.c.a_id == a_table.c.id
         )
 
-        self._fixture({"summation": column_property(cp)})
+        self._fixture({"summation": column_property(cp.scalar_subquery())})
         self.assert_compile(
             create_session()
             .query(A)
@@ -3306,7 +3306,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             .correlate(a_table)
         )
 
-        self._fixture({"summation": column_property(cp)})
+        self._fixture({"summation": column_property(cp.scalar_subquery())})
         self.assert_compile(
             create_session()
             .query(A)
@@ -3330,7 +3330,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             select([func.sum(b_table.c.value)])
             .where(b_table.c.a_id == a_table.c.id)
             .correlate(a_table)
-            .as_scalar()
+            .scalar_subquery()
         )
 
         # up until 0.8, this was ordering by a new subquery.
@@ -3360,7 +3360,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             select([func.sum(b_table.c.value)])
             .where(b_table.c.a_id == a_table.c.id)
             .correlate(a_table)
-            .as_scalar()
+            .scalar_subquery()
             .label("foo")
         )
         self.assert_compile(
@@ -3387,7 +3387,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             select([func.sum(b_table.c.value)])
             .where(b_table.c.a_id == a_table.c.id)
             .correlate(a_table)
-            .as_scalar()
+            .scalar_subquery()
         )
         # test a different unary operator
         self.assert_compile(
@@ -4689,7 +4689,7 @@ class SubqueryTest(fixtures.MappedTest):
                 tag_score = tag_score.label(labelname)
                 user_score = user_score.label(labelname)
             else:
-                user_score = user_score.as_scalar()
+                user_score = user_score.scalar_subquery()
 
             mapper(
                 Tag,
@@ -4798,40 +4798,28 @@ class CorrelatedSubqueryTest(fixtures.MappedTest):
         )
 
     def test_labeled_on_date_noalias(self):
-        self._do_test("label", True, False)
+        self._do_test(True, True, False)
 
     def test_scalar_on_date_noalias(self):
-        self._do_test("scalar", True, False)
-
-    def test_plain_on_date_noalias(self):
-        self._do_test("none", True, False)
+        self._do_test(False, True, False)
 
     def test_labeled_on_limitid_noalias(self):
-        self._do_test("label", False, False)
+        self._do_test(True, False, False)
 
     def test_scalar_on_limitid_noalias(self):
-        self._do_test("scalar", False, False)
-
-    def test_plain_on_limitid_noalias(self):
-        self._do_test("none", False, False)
+        self._do_test(False, False, False)
 
     def test_labeled_on_date_alias(self):
-        self._do_test("label", True, True)
+        self._do_test(True, True, True)
 
     def test_scalar_on_date_alias(self):
-        self._do_test("scalar", True, True)
-
-    def test_plain_on_date_alias(self):
-        self._do_test("none", True, True)
+        self._do_test(False, True, True)
 
     def test_labeled_on_limitid_alias(self):
-        self._do_test("label", False, True)
+        self._do_test(True, False, True)
 
     def test_scalar_on_limitid_alias(self):
-        self._do_test("scalar", False, True)
-
-    def test_plain_on_limitid_alias(self):
-        self._do_test("none", False, True)
+        self._do_test(False, False, True)
 
     def _do_test(self, labeled, ondate, aliasstuff):
         stuff, users = self.tables.stuff, self.tables.users
@@ -4843,7 +4831,6 @@ class CorrelatedSubqueryTest(fixtures.MappedTest):
             pass
 
         mapper(Stuff, stuff)
-
         if aliasstuff:
             salias = stuff.alias()
         else:
@@ -4879,11 +4866,11 @@ class CorrelatedSubqueryTest(fixtures.MappedTest):
         else:
             operator = operators.eq
 
-        if labeled == "label":
+        if labeled:
             stuff_view = stuff_view.label("foo")
             operator = operators.eq
-        elif labeled == "scalar":
-            stuff_view = stuff_view.as_scalar()
+        else:
+            stuff_view = stuff_view.scalar_subquery()
 
         if ondate:
             mapper(
index 1cccfff2684e3646c8a8b4ee05ec85d9d61725e9..3cec10e6815251d5e10805a0acd9f8bb3f3f904a 100644 (file)
@@ -161,79 +161,79 @@ class QueryCorrelatesLikeSelect(QueryTest, AssertsCompiledSQL):
         "WHERE addresses.user_id = users.id) AS anon_1 FROM users"
     )
 
-    def test_as_scalar_select_auto_correlate(self):
+    def test_scalar_subquery_select_auto_correlate(self):
         addresses, users = self.tables.addresses, self.tables.users
         query = select(
             [func.count(addresses.c.id)], addresses.c.user_id == users.c.id
-        ).as_scalar()
+        ).scalar_subquery()
         query = select([users.c.name.label("users_name"), query])
         self.assert_compile(
             query, self.query_correlated, dialect=default.DefaultDialect()
         )
 
-    def test_as_scalar_select_explicit_correlate(self):
+    def test_scalar_subquery_select_explicit_correlate(self):
         addresses, users = self.tables.addresses, self.tables.users
         query = (
             select(
                 [func.count(addresses.c.id)], addresses.c.user_id == users.c.id
             )
             .correlate(users)
-            .as_scalar()
+            .scalar_subquery()
         )
         query = select([users.c.name.label("users_name"), query])
         self.assert_compile(
             query, self.query_correlated, dialect=default.DefaultDialect()
         )
 
-    def test_as_scalar_select_correlate_off(self):
+    def test_scalar_subquery_select_correlate_off(self):
         addresses, users = self.tables.addresses, self.tables.users
         query = (
             select(
                 [func.count(addresses.c.id)], addresses.c.user_id == users.c.id
             )
             .correlate(None)
-            .as_scalar()
+            .scalar_subquery()
         )
         query = select([users.c.name.label("users_name"), query])
         self.assert_compile(
             query, self.query_not_correlated, dialect=default.DefaultDialect()
         )
 
-    def test_as_scalar_query_auto_correlate(self):
+    def test_scalar_subquery_query_auto_correlate(self):
         sess = create_session()
         Address, User = self.classes.Address, self.classes.User
         query = (
             sess.query(func.count(Address.id))
             .filter(Address.user_id == User.id)
-            .as_scalar()
+            .scalar_subquery()
         )
         query = sess.query(User.name, query)
         self.assert_compile(
             query, self.query_correlated, dialect=default.DefaultDialect()
         )
 
-    def test_as_scalar_query_explicit_correlate(self):
+    def test_scalar_subquery_query_explicit_correlate(self):
         sess = create_session()
         Address, User = self.classes.Address, self.classes.User
         query = (
             sess.query(func.count(Address.id))
             .filter(Address.user_id == User.id)
             .correlate(self.tables.users)
-            .as_scalar()
+            .scalar_subquery()
         )
         query = sess.query(User.name, query)
         self.assert_compile(
             query, self.query_correlated, dialect=default.DefaultDialect()
         )
 
-    def test_as_scalar_query_correlate_off(self):
+    def test_scalar_subquery_query_correlate_off(self):
         sess = create_session()
         Address, User = self.classes.Address, self.classes.User
         query = (
             sess.query(func.count(Address.id))
             .filter(Address.user_id == User.id)
             .correlate(None)
-            .as_scalar()
+            .scalar_subquery()
         )
         query = sess.query(User.name, query)
         self.assert_compile(
@@ -3243,7 +3243,7 @@ class ExternalColumnsTest(QueryTest):
                         users.c.id == addresses.c.user_id,
                     )
                     .correlate(users)
-                    .as_scalar()
+                    .scalar_subquery()
                 ),
             },
         )
@@ -3398,7 +3398,9 @@ class ExternalColumnsTest(QueryTest):
                     select(
                         [func.count(addresses.c.id)],
                         users.c.id == addresses.c.user_id,
-                    ).correlate(users)
+                    )
+                    .correlate(users)
+                    .scalar_subquery()
                 ),
             },
         )
index 78680701e0bdb266b623c4c3d81aefa6ba50639d..8d1e82fb540bf5e8f050454bc67db60d9767b90b 100644 (file)
@@ -1290,7 +1290,7 @@ class CorrelatedTest(fixtures.MappedTest):
                     Stuff,
                     primaryjoin=sa.and_(
                         user_t.c.id == stuff.c.user_id,
-                        stuff.c.id == (stuff_view.as_scalar()),
+                        stuff.c.id == (stuff_view.scalar_subquery()),
                     ),
                 )
             },
index 93cf19faea6a5853b6f9faad5018e1fbe00ed2a6..fa1f1fdf88d00a7bd048b4a6e134fe7289dcb677 100644 (file)
@@ -764,7 +764,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         expr = User.name + "name"
         expr2 = sa.select([User.name, users.c.id])
         m.add_property("x", column_property(expr))
-        m.add_property("y", column_property(expr2))
+        m.add_property("y", column_property(expr2.scalar_subquery()))
 
         assert User.x.property.columns[0] is not expr
         assert User.x.property.columns[0].element.left is users.c.name
index 86b304be9de7b70caaf2182d52470dbc7b66ce78..27176d6fb98665f29bf3b20643e92c557f87774b 100644 (file)
@@ -467,7 +467,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
                     select([func.count(Address.id)])
                     .where(User.id == Address.user_id)
                     .correlate(User)
-                    .as_scalar(),
+                    .scalar_subquery(),
                 ]
             ),
             "SELECT users.name, addresses.id, "
@@ -489,7 +489,7 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
                     select([func.count(Address.id)])
                     .where(uu.id == Address.user_id)
                     .correlate(uu)
-                    .as_scalar(),
+                    .scalar_subquery(),
                 ]
             ),
             # for a long time, "uu.id = address.user_id" was reversed;
@@ -1072,7 +1072,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
 
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Object .*User.* is not legal as a SQL literal value",
+            "SQL expression element expected, got .*User",
             distinct,
             User,
         )
@@ -1080,7 +1080,7 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
         ua = aliased(User)
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Object .*User.* is not legal as a SQL literal value",
+            "SQL expression element expected, got .*User",
             distinct,
             ua,
         )
@@ -1088,21 +1088,21 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
         s = Session()
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Object .*User.* is not legal as a SQL literal value",
+            "SQL expression element or literal value expected, got .*User",
             lambda: s.query(User).filter(User.name == User),
         )
 
         u1 = User()
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Object .*User.* is not legal as a SQL literal value",
+            "SQL expression element expected, got .*User",
             distinct,
             u1,
         )
 
         assert_raises_message(
             sa_exc.ArgumentError,
-            "Object .*User.* is not legal as a SQL literal value",
+            "SQL expression element or literal value expected, got .*User",
             lambda: s.query(User).filter(User.name == u1),
         )
 
@@ -1757,16 +1757,17 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
 
         session = create_session()
 
-        q = session.query(User.id).filter(User.id == 7)
+        q = session.query(User.id).filter(User.id == 7).scalar_subquery()
 
         q = session.query(Address).filter(Address.user_id == q)
+
         assert isinstance(q._criterion.right, expression.ColumnElement)
         self.assert_compile(
             q,
             "SELECT addresses.id AS addresses_id, addresses.user_id "
             "AS addresses_user_id, addresses.email_address AS "
             "addresses_email_address FROM addresses WHERE "
-            "addresses.user_id = (SELECT users.id AS users_id "
+            "addresses.user_id = (SELECT users.id "
             "FROM users WHERE users.id = :id_1)",
         )
 
@@ -1842,12 +1843,12 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
             "WHERE users.id = :id_1) AS foo",
         )
 
-    def test_as_scalar(self):
+    def test_scalar_subquery(self):
         User = self.classes.User
 
         session = create_session()
 
-        q = session.query(User.id).filter(User.id == 7).as_scalar()
+        q = session.query(User.id).filter(User.id == 7).scalar_subquery()
 
         self.assert_compile(
             session.query(User).filter(User.id.in_(q)),
@@ -1866,7 +1867,7 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
             session.query(User.id)
             .filter(User.id == bindparam("foo"))
             .params(foo=7)
-            .subquery()
+            .scalar_subquery()
         )
 
         q = session.query(User).filter(User.id.in_(q))
@@ -2081,6 +2082,8 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         )
         if label:
             stmt = stmt.label("email_ad")
+        else:
+            stmt = stmt.scalar_subquery()
 
         mapper(
             User,
@@ -5491,7 +5494,7 @@ class SessionBindTest(QueryTest):
             column_property(
                 select([func.sum(Address.id)])
                 .where(Address.user_id == User.id)
-                .as_scalar()
+                .scalar_subquery()
             ),
         )
         session = Session()
index 5e6ac53fe8582858dc07501b21c3a51e56b7d1ce..259d5ea9c2e3b64c1e89b1f3aa52ab98820c55ba 100644 (file)
@@ -470,7 +470,7 @@ class _JoinFixtures(object):
             self.left,
             self.right,
             primaryjoin=self.left.c.id == func.foo(self.right.c.lid),
-            consider_as_foreign_keys=[self.right.c.lid],
+            consider_as_foreign_keys={self.right.c.lid},
             **kw
         )
 
@@ -480,10 +480,10 @@ class _JoinFixtures(object):
             self.composite_multi_ref,
             self.composite_target,
             self.composite_multi_ref,
-            consider_as_foreign_keys=[
+            consider_as_foreign_keys={
                 self.composite_multi_ref.c.uid2,
                 self.composite_multi_ref.c.oid,
-            ],
+            },
             **kw
         )
 
@@ -1099,10 +1099,10 @@ class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
             self.m2mleft,
             self.m2mright,
             secondary=self.m2msecondary_ambig_fks,
-            consider_as_foreign_keys=[
+            consider_as_foreign_keys={
                 self.m2msecondary_ambig_fks.c.lid1,
                 self.m2msecondary_ambig_fks.c.rid1,
-            ],
+            },
         )
 
     def test_determine_join_w_fks_ambig_m2m(self):
index 652b8bacd657cc88cfbbbe8e3c7e7ccfcccfacd6..494ec2b0dca2773975260eedd5c11aa3a3d40592 100644 (file)
@@ -2288,9 +2288,8 @@ class JoinConditionErrorTest(fixtures.TestBase):
 
             assert_raises_message(
                 sa.exc.ArgumentError,
-                "Column-based expression object expected "
-                "for argument '%s'; got: '%s', type %r"
-                % (argname, arg[0], type(arg[0])),
+                "Column expression expected "
+                "for argument '%s'; got '%s'" % (argname, arg[0]),
                 configure_mappers,
             )
 
index 9e3f8074fc47c9ad372db8f011a0dad6d78dffb5..217a4f77aaf42035458399cbf294e979d3b94167 100644 (file)
@@ -326,13 +326,15 @@ class UpdateDeleteTest(fixtures.MappedTest):
         assert_raises(
             exc.InvalidRequestError,
             sess.query(User)
-            .filter(User.name == select([func.max(User.name)]))
+            .filter(
+                User.name == select([func.max(User.name)]).scalar_subquery()
+            )
             .delete,
             synchronize_session="evaluate",
         )
 
         sess.query(User).filter(
-            User.name == select([func.max(User.name)])
+            User.name == select([func.max(User.name)]).scalar_subquery()
         ).delete(synchronize_session="fetch")
 
         assert john not in sess
@@ -969,7 +971,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
         subq = (
             s.query(func.max(Document.title).label("title"))
             .group_by(Document.user_id)
-            .subquery()
+            .scalar_subquery()
         )
 
         s.query(Document).filter(Document.title.in_(subq)).update(
@@ -999,7 +1001,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
         subq = (
             s.query(func.max(Document.title).label("title"))
             .group_by(Document.user_id)
-            .subquery()
+            .scalar_subquery()
         )
 
         # this would work with Firebird if you do literal_column('1')
index 5c51d623f6df9fb1b10be89b183ebc290505d17c..d12b3c4ae068bcc6eead58265197af70af7b4dcd 100644 (file)
@@ -1,15 +1,15 @@
 # /home/classic/dev/sqlalchemy/test/profiles.txt
 # This file is written out on a per-environment basis.
-# For each test in aaa_profiling, the corresponding function and 
+# For each test in aaa_profiling, the corresponding function and
 # environment is located within this file.  If it doesn't exist,
 # the test is skipped.
-# If a callcount does exist, it is compared to what we received. 
+# If a callcount does exist, it is compared to what we received.
 # assertions are raised if the counts do not match.
-# 
-# To add a new callcount test, apply the function_call_count 
-# decorator and re-run the tests using the --write-profiles 
+#
+# To add a new callcount test, apply the function_call_count
+# decorator and re-run the tests using the --write-profiles
 # option - this file will be rewritten including the new count.
-# 
+#
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
 
@@ -803,14 +803,14 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_unicode 3.7_sqlite_pysqlite
 
 # TEST: test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation
 
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6010,303,3905,12590,1177,2109,2565
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6054,303,4025,13894,1292,2122,2796
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 5742,282,3865,12494,1163,2042,2604
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 5808,282,3993,13758,1273,2061,2816
+test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6401,324,3953,12769,1200,2189,2624
+test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6445,324,4073,14073,1315,2202,2855
+test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6116,303,3913,12679,1186,2122,2667
+test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6165,303,4041,13945,1295,2141,2883
 
 # TEST: test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation
 
-test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6627,413,6961,18335,1191,2723
-test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6719,418,7081,19404,1297,2758
-test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6601,403,7149,18918,1178,2790
-test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6696,408,7285,20029,1280,2831
+test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6957,409,7143,18720,1214,2829
+test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 7053,414,7263,19789,1320,2864
+test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6946,400,7331,19307,1201,2895
+test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7046,405,7467,20418,1302,2935
index 7bb612cfa42f4a50714707f4593f53828f44cdf5..491ff42bc9e2938b65ece8075c9f77cb970b0ad3 100644 (file)
@@ -131,7 +131,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_literal_interpretation_ambiguous(self):
         assert_raises_message(
             exc.ArgumentError,
-            r"Ambiguous literal: 'x'.  Use the 'text\(\)' function",
+            r"Column expression expected, got 'x'",
             case,
             [("x", "y")],
         )
@@ -139,7 +139,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_literal_interpretation_ambiguous_tuple(self):
         assert_raises_message(
             exc.ArgumentError,
-            r"Ambiguous literal: \('x', 'y'\).  Use the 'text\(\)' function",
+            r"Column expression expected, got \('x', 'y'\)",
             case,
             [(("x", "y"), "z")],
         )
index 3608851ed6e79b2ab2c55c5cc04c0918d5bad8d3..f2feea757213a78399d9c2845e5f96ac8dc960d9 100644 (file)
@@ -243,8 +243,8 @@ class CompareAndCopyTest(fixtures.TestBase):
             FromGrouping(table_a.alias("b")),
         ),
         lambda: (
-            select([table_a.c.a]).as_scalar(),
-            select([table_a.c.a]).where(table_a.c.b == 5).as_scalar(),
+            select([table_a.c.a]).scalar_subquery(),
+            select([table_a.c.a]).where(table_a.c.b == 5).scalar_subquery(),
         ),
         lambda: (
             exists().where(table_a.c.a == 5),
@@ -291,6 +291,7 @@ class CompareAndCopyTest(fixtures.TestBase):
             and "__init__" in cls.__dict__
             and not issubclass(cls, (Annotated))
             and "orm" not in cls.__module__
+            and "compiler" not in cls.__module__
             and "crud" not in cls.__module__
             and "dialects" not in cls.__module__  # TODO: dialects?
         ).difference({ColumnElement, UnaryExpression})
index 9d6e17a1d285fd9fd5e608561cec0d7fb4876158..e012c2713eccb79b84ebe4e4fd61e6c4ccf2a53e 100644 (file)
@@ -71,7 +71,6 @@ from sqlalchemy.sql import column
 from sqlalchemy.sql import compiler
 from sqlalchemy.sql import label
 from sqlalchemy.sql import table
-from sqlalchemy.sql.expression import _literal_as_text
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.sql.expression import HasPrefixes
 from sqlalchemy.testing import assert_raises
@@ -165,7 +164,8 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 "columns; use this object directly within a "
                 "column-level expression.",
                 lambda: hasattr(
-                    select([table1.c.myid]).as_scalar().self_group(), "columns"
+                    select([table1.c.myid]).scalar_subquery().self_group(),
+                    "columns",
                 ),
             )
             assert_raises_message(
@@ -174,14 +174,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 "columns; use this object directly within a "
                 "column-level expression.",
                 lambda: hasattr(
-                    select([table1.c.myid]).as_scalar(), "columns"
+                    select([table1.c.myid]).scalar_subquery(), "columns"
                 ),
             )
         else:
             assert not hasattr(
-                select([table1.c.myid]).as_scalar().self_group(), "columns"
+                select([table1.c.myid]).scalar_subquery().self_group(),
+                "columns",
+            )
+            assert not hasattr(
+                select([table1.c.myid]).scalar_subquery(), "columns"
             )
-            assert not hasattr(select([table1.c.myid]).as_scalar(), "columns")
 
     def test_prefix_constructor(self):
         class Pref(HasPrefixes):
@@ -327,7 +330,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_select_precol_compile_ordering(self):
-        s1 = select([column("x")]).select_from(text("a")).limit(5).as_scalar()
+        s1 = (
+            select([column("x")])
+            .select_from(text("a"))
+            .limit(5)
+            .scalar_subquery()
+        )
         s2 = select([s1]).limit(10)
 
         class MyCompiler(compiler.SQLCompiler):
@@ -631,7 +639,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
         self.assert_compile(
-            exists(s.as_scalar()),
+            exists(s.scalar_subquery()),
             "EXISTS (SELECT mytable.myid FROM mytable "
             "WHERE mytable.myid = :myid_1)",
         )
@@ -754,22 +762,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "addresses, users WHERE addresses.user_id = "
             "users.user_id) AS s",
         )
-        self.assert_compile(
-            table1.select(
-                table1.c.myid
-                == select([table1.c.myid], table1.c.name == "jack")
-            ),
-            "SELECT mytable.myid, mytable.name, "
-            "mytable.description FROM mytable WHERE "
-            "mytable.myid = (SELECT mytable.myid FROM "
-            "mytable WHERE mytable.name = :name_1)",
-        )
         self.assert_compile(
             table1.select(
                 table1.c.myid
                 == select(
                     [table2.c.otherid], table1.c.name == table2.c.othername
-                )
+                ).scalar_subquery()
             ),
             "SELECT mytable.myid, mytable.name, "
             "mytable.description FROM mytable WHERE "
@@ -822,7 +820,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 order_by=[
                     select(
                         [table2.c.otherid], table1.c.myid == table2.c.otherid
-                    )
+                    ).scalar_subquery()
                 ]
             ),
             "SELECT mytable.myid, mytable.name, "
@@ -831,42 +829,22 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "myothertable WHERE mytable.myid = "
             "myothertable.otherid)",
         )
-        self.assert_compile(
-            table1.select(
-                order_by=[
-                    desc(
-                        select(
-                            [table2.c.otherid],
-                            table1.c.myid == table2.c.otherid,
-                        )
-                    )
-                ]
-            ),
-            "SELECT mytable.myid, mytable.name, "
-            "mytable.description FROM mytable ORDER BY "
-            "(SELECT myothertable.otherid FROM "
-            "myothertable WHERE mytable.myid = "
-            "myothertable.otherid) DESC",
-        )
 
     def test_scalar_select(self):
-        assert_raises_message(
-            exc.InvalidRequestError,
-            r"Select objects don't have a type\.  Call as_scalar\(\) "
-            r"on this Select object to return a 'scalar' "
-            r"version of this Select\.",
-            func.coalesce,
-            select([table1.c.myid]),
+
+        self.assert_compile(
+            func.coalesce(select([table1.c.myid]).scalar_subquery()),
+            "coalesce((SELECT mytable.myid FROM mytable))",
         )
 
-        s = select([table1.c.myid], correlate=False).as_scalar()
+        s = select([table1.c.myid], correlate=False).scalar_subquery()
         self.assert_compile(
             select([table1, s]),
             "SELECT mytable.myid, mytable.name, "
             "mytable.description, (SELECT mytable.myid "
             "FROM mytable) AS anon_1 FROM mytable",
         )
-        s = select([table1.c.myid]).as_scalar()
+        s = select([table1.c.myid]).scalar_subquery()
         self.assert_compile(
             select([table2, s]),
             "SELECT myothertable.otherid, "
@@ -874,7 +852,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "mytable.myid FROM mytable) AS anon_1 FROM "
             "myothertable",
         )
-        s = select([table1.c.myid]).correlate(None).as_scalar()
+        s = select([table1.c.myid]).correlate(None).scalar_subquery()
         self.assert_compile(
             select([table1, s]),
             "SELECT mytable.myid, mytable.name, "
@@ -882,17 +860,17 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM mytable) AS anon_1 FROM mytable",
         )
 
-        s = select([table1.c.myid]).as_scalar()
+        s = select([table1.c.myid]).scalar_subquery()
         s2 = s.where(table1.c.myid == 5)
         self.assert_compile(
             s2,
             "(SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)",
         )
         self.assert_compile(s, "(SELECT mytable.myid FROM mytable)")
-        # test that aliases use as_scalar() when used in an explicitly
+        # test that aliases use scalar_subquery() when used in an explicitly
         # scalar context
 
-        s = select([table1.c.myid]).alias()
+        s = select([table1.c.myid]).scalar_subquery()
         self.assert_compile(
             select([table1.c.myid]).where(table1.c.myid == s),
             "SELECT mytable.myid FROM mytable WHERE "
@@ -902,10 +880,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             select([table1.c.myid]).where(s > table1.c.myid),
             "SELECT mytable.myid FROM mytable WHERE "
-            "mytable.myid < (SELECT mytable.myid FROM "
-            "mytable)",
+            "(SELECT mytable.myid FROM mytable) > mytable.myid",
         )
-        s = select([table1.c.myid]).as_scalar()
+        s = select([table1.c.myid]).scalar_subquery()
         self.assert_compile(
             select([table2, s]),
             "SELECT myothertable.otherid, "
@@ -922,7 +899,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "- :param_1 AS anon_1",
         )
         self.assert_compile(
-            select([select([table1.c.name]).as_scalar() + literal("x")]),
+            select([select([table1.c.name]).scalar_subquery() + literal("x")]),
             "SELECT (SELECT mytable.name FROM mytable) "
             "|| :param_1 AS anon_1",
         )
@@ -939,7 +916,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         # scalar selects should not have any attributes on their 'c' or
         # 'columns' attribute
 
-        s = select([table1.c.myid]).as_scalar()
+        s = select([table1.c.myid]).scalar_subquery()
         try:
             s.c.foo
         except exc.InvalidRequestError as err:
@@ -965,12 +942,12 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         qlat = (
             select([zips.c.latitude], zips.c.zipcode == zipcode)
             .correlate(None)
-            .as_scalar()
+            .scalar_subquery()
         )
         qlng = (
             select([zips.c.longitude], zips.c.zipcode == zipcode)
             .correlate(None)
-            .as_scalar()
+            .scalar_subquery()
         )
 
         q = select(
@@ -999,10 +976,10 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         zalias = zips.alias("main_zip")
         qlat = select(
             [zips.c.latitude], zips.c.zipcode == zalias.c.zipcode
-        ).as_scalar()
+        ).scalar_subquery()
         qlng = select(
             [zips.c.longitude], zips.c.zipcode == zalias.c.zipcode
-        ).as_scalar()
+        ).scalar_subquery()
         q = select(
             [
                 places.c.id,
@@ -1025,7 +1002,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
         a1 = table2.alias("t2alias")
-        s1 = select([a1.c.otherid], table1.c.myid == a1.c.otherid).as_scalar()
+        s1 = select(
+            [a1.c.otherid], table1.c.myid == a1.c.otherid
+        ).scalar_subquery()
         j1 = table1.join(table2, table1.c.myid == table2.c.otherid)
         s2 = select([table1, s1], from_obj=j1)
         self.assert_compile(
@@ -2337,7 +2316,11 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         assert s3.compile().params == {"myid": 9, "myotherid": 7}
 
         # test using same 'unique' param object twice in one compile
-        s = select([table1.c.myid]).where(table1.c.myid == 12).as_scalar()
+        s = (
+            select([table1.c.myid])
+            .where(table1.c.myid == 12)
+            .scalar_subquery()
+        )
         s2 = select([table1, s], table1.c.myid == s)
         self.assert_compile(
             s2,
@@ -2884,7 +2867,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             lambda: sel1.c,
         )
 
-        # calling label or as_scalar doesn't compile
+        # calling label or scalar_subquery doesn't compile
         # anything.
         sel2 = select([func.substr(my_str, 2, 3)]).label("my_substr")
 
@@ -2895,7 +2878,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=default.DefaultDialect(),
         )
 
-        sel3 = select([my_str]).as_scalar()
+        sel3 = select([my_str]).scalar_subquery()
         assert_raises_message(
             exc.CompileError,
             "Cannot compile Column object until its 'name' is assigned.",
@@ -3919,13 +3902,13 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_correlate_semiauto_where(self):
         t1, t2, s1 = self._fixture()
         self._assert_where_correlated(
-            select([t2]).where(t2.c.a == s1.correlate(t2))
+            select([t2]).where(t2.c.a == s1.correlate(t2).scalar_subquery())
         )
 
     def test_correlate_semiauto_column(self):
         t1, t2, s1 = self._fixture()
         self._assert_column_correlated(
-            select([t2, s1.correlate(t2).as_scalar()])
+            select([t2, s1.correlate(t2).scalar_subquery()])
         )
 
     def test_correlate_semiauto_from(self):
@@ -3935,31 +3918,35 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_correlate_semiauto_having(self):
         t1, t2, s1 = self._fixture()
         self._assert_having_correlated(
-            select([t2]).having(t2.c.a == s1.correlate(t2))
+            select([t2]).having(t2.c.a == s1.correlate(t2).scalar_subquery())
         )
 
     def test_correlate_except_inclusion_where(self):
         t1, t2, s1 = self._fixture()
         self._assert_where_correlated(
-            select([t2]).where(t2.c.a == s1.correlate_except(t1))
+            select([t2]).where(
+                t2.c.a == s1.correlate_except(t1).scalar_subquery()
+            )
         )
 
     def test_correlate_except_exclusion_where(self):
         t1, t2, s1 = self._fixture()
         self._assert_where_uncorrelated(
-            select([t2]).where(t2.c.a == s1.correlate_except(t2))
+            select([t2]).where(
+                t2.c.a == s1.correlate_except(t2).scalar_subquery()
+            )
         )
 
     def test_correlate_except_inclusion_column(self):
         t1, t2, s1 = self._fixture()
         self._assert_column_correlated(
-            select([t2, s1.correlate_except(t1).as_scalar()])
+            select([t2, s1.correlate_except(t1).scalar_subquery()])
         )
 
     def test_correlate_except_exclusion_column(self):
         t1, t2, s1 = self._fixture()
         self._assert_column_uncorrelated(
-            select([t2, s1.correlate_except(t2).as_scalar()])
+            select([t2, s1.correlate_except(t2).scalar_subquery()])
         )
 
     def test_correlate_except_inclusion_from(self):
@@ -3977,22 +3964,28 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_correlate_except_none(self):
         t1, t2, s1 = self._fixture()
         self._assert_where_all_correlated(
-            select([t1, t2]).where(t2.c.a == s1.correlate_except(None))
+            select([t1, t2]).where(
+                t2.c.a == s1.correlate_except(None).scalar_subquery()
+            )
         )
 
     def test_correlate_except_having(self):
         t1, t2, s1 = self._fixture()
         self._assert_having_correlated(
-            select([t2]).having(t2.c.a == s1.correlate_except(t1))
+            select([t2]).having(
+                t2.c.a == s1.correlate_except(t1).scalar_subquery()
+            )
         )
 
     def test_correlate_auto_where(self):
         t1, t2, s1 = self._fixture()
-        self._assert_where_correlated(select([t2]).where(t2.c.a == s1))
+        self._assert_where_correlated(
+            select([t2]).where(t2.c.a == s1.scalar_subquery())
+        )
 
     def test_correlate_auto_column(self):
         t1, t2, s1 = self._fixture()
-        self._assert_column_correlated(select([t2, s1.as_scalar()]))
+        self._assert_column_correlated(select([t2, s1.scalar_subquery()]))
 
     def test_correlate_auto_from(self):
         t1, t2, s1 = self._fixture()
@@ -4000,18 +3993,20 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def test_correlate_auto_having(self):
         t1, t2, s1 = self._fixture()
-        self._assert_having_correlated(select([t2]).having(t2.c.a == s1))
+        self._assert_having_correlated(
+            select([t2]).having(t2.c.a == s1.scalar_subquery())
+        )
 
     def test_correlate_disabled_where(self):
         t1, t2, s1 = self._fixture()
         self._assert_where_uncorrelated(
-            select([t2]).where(t2.c.a == s1.correlate(None))
+            select([t2]).where(t2.c.a == s1.correlate(None).scalar_subquery())
         )
 
     def test_correlate_disabled_column(self):
         t1, t2, s1 = self._fixture()
         self._assert_column_uncorrelated(
-            select([t2, s1.correlate(None).as_scalar()])
+            select([t2, s1.correlate(None).scalar_subquery()])
         )
 
     def test_correlate_disabled_from(self):
@@ -4023,19 +4018,21 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_correlate_disabled_having(self):
         t1, t2, s1 = self._fixture()
         self._assert_having_uncorrelated(
-            select([t2]).having(t2.c.a == s1.correlate(None))
+            select([t2]).having(t2.c.a == s1.correlate(None).scalar_subquery())
         )
 
     def test_correlate_all_where(self):
         t1, t2, s1 = self._fixture()
         self._assert_where_all_correlated(
-            select([t1, t2]).where(t2.c.a == s1.correlate(t1, t2))
+            select([t1, t2]).where(
+                t2.c.a == s1.correlate(t1, t2).scalar_subquery()
+            )
         )
 
     def test_correlate_all_column(self):
         t1, t2, s1 = self._fixture()
         self._assert_column_all_correlated(
-            select([t1, t2, s1.correlate(t1, t2).as_scalar()])
+            select([t1, t2, s1.correlate(t1, t2).scalar_subquery()])
         )
 
     def test_correlate_all_from(self):
@@ -4049,7 +4046,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
         assert_raises_message(
             exc.InvalidRequestError,
             "returned no FROM clauses due to auto-correlation",
-            select([t1, t2]).where(t2.c.a == s1).compile,
+            select([t1, t2]).where(t2.c.a == s1.scalar_subquery()).compile,
         )
 
     def test_correlate_from_all_ok(self):
@@ -4063,7 +4060,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_correlate_auto_where_singlefrom(self):
         t1, t2, s1 = self._fixture()
         s = select([t1.c.a])
-        s2 = select([t1]).where(t1.c.a == s)
+        s2 = select([t1]).where(t1.c.a == s.scalar_subquery())
         self.assert_compile(
             s2, "SELECT t1.a FROM t1 WHERE t1.a = " "(SELECT t1.a FROM t1)"
         )
@@ -4073,7 +4070,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
 
         s = select([t1.c.a])
 
-        s2 = select([t1]).where(t1.c.a == s.correlate(t1))
+        s2 = select([t1]).where(t1.c.a == s.correlate(t1).scalar_subquery())
         self._assert_where_single_full_correlated(s2)
 
     def test_correlate_except_semiauto_where_singlefrom(self):
@@ -4081,7 +4078,9 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
 
         s = select([t1.c.a])
 
-        s2 = select([t1]).where(t1.c.a == s.correlate_except(t2))
+        s2 = select([t1]).where(
+            t1.c.a == s.correlate_except(t2).scalar_subquery()
+        )
         self._assert_where_single_full_correlated(s2)
 
     def test_correlate_alone_noeffect(self):
@@ -4098,7 +4097,7 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
         s = select([t2.c.b]).where(t1.c.a == t2.c.a)
         s = s.correlate_except(t2).alias("s")
 
-        s2 = select([func.foo(s.c.b)]).as_scalar()
+        s2 = select([func.foo(s.c.b)]).scalar_subquery()
         s3 = select([t1], order_by=s2)
 
         self.assert_compile(
@@ -4155,8 +4154,8 @@ class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
         t3 = table("t3", column("z"))
 
         s = select([t1.c.x]).where(t1.c.x == t2.c.y)
-        s2 = select([t3.c.z]).where(t3.c.z == s.as_scalar())
-        s3 = select([t1]).where(t1.c.x == s2.as_scalar())
+        s2 = select([t3.c.z]).where(t3.c.z == s.scalar_subquery())
+        s3 = select([t1]).where(t1.c.x == s2.scalar_subquery())
 
         self.assert_compile(
             s3,
@@ -4214,15 +4213,6 @@ class CoercionTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=default.DefaultDialect(supports_native_boolean=False),
         )
 
-    def test_null_constant(self):
-        self.assert_compile(_literal_as_text(None), "NULL")
-
-    def test_false_constant(self):
-        self.assert_compile(_literal_as_text(False), "false")
-
-    def test_true_constant(self):
-        self.assert_compile(_literal_as_text(True), "true")
-
     def test_val_and_false(self):
         t = self._fixture()
         self.assert_compile(and_(t.c.id == 1, False), "false")
index 7008bc1cca5631718be3130619e456ac37f807e8..ac46e7d5d94c96a93eaff48fc7aebead8e7180df 100644 (file)
@@ -47,7 +47,9 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             select([regional_sales.c.region])
             .where(
                 regional_sales.c.total_sales
-                > select([func.sum(regional_sales.c.total_sales) / 10])
+                > select(
+                    [func.sum(regional_sales.c.total_sales) / 10]
+                ).scalar_subquery()
             )
             .cte("top_regions")
         )
index 76ef38e1f1379a80ca7ae024a2e228b8c8a5be4e..ed7af2572e2489ca8091e3d7a07016d68a8a03d7 100644 (file)
@@ -657,8 +657,8 @@ class DefaultTest(fixtures.TestBase):
         ):
             assert_raises_message(
                 sa.exc.ArgumentError,
-                "SQL expression object expected, got object of type "
-                "<.* 'list'> instead",
+                r"SQL expression for WHERE/HAVING role expected, "
+                r"got \[(?:Sequence|ColumnDefault|DefaultClause)\('y'.*\)\]",
                 t.select,
                 [const],
             )
@@ -913,7 +913,7 @@ class PKDefaultTest(fixtures.TablesTest):
                 "id",
                 Integer,
                 primary_key=True,
-                default=sa.select([func.max(t2.c.nextid)]).as_scalar(),
+                default=sa.select([func.max(t2.c.nextid)]).scalar_subquery(),
             ),
             Column("data", String(30)),
         )
@@ -1740,7 +1740,7 @@ class SpecialTypePKTest(fixtures.TestBase):
         self._run_test(server_default="1", autoincrement=False)
 
     def test_clause(self):
-        stmt = select([cast("INT_1", type_=self.MyInteger)]).as_scalar()
+        stmt = select([cast("INT_1", type_=self.MyInteger)]).scalar_subquery()
         self._run_test(default=stmt)
 
     @testing.requires.returning
index f572a510ca5a625063a9180a51b65eddd55eeb34..1f4c49c562a8443c3526b2c5d02e5cfb8d31f431 100644 (file)
@@ -109,13 +109,13 @@ class DeleteTest(_DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             stmt, "DELETE FROM mytable AS t1 WHERE t1.myid = :myid_1"
         )
 
-    def test_correlated(self):
+    def test_non_correlated_select(self):
         table1, table2 = self.tables.mytable, self.tables.myothertable
 
         # test a non-correlated WHERE clause
         s = select([table2.c.othername], table2.c.otherid == 7)
         self.assert_compile(
-            delete(table1, table1.c.name == s),
+            delete(table1, table1.c.name == s.scalar_subquery()),
             "DELETE FROM mytable "
             "WHERE mytable.name = ("
             "SELECT myothertable.othername "
@@ -124,10 +124,13 @@ class DeleteTest(_DeleteTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             ")",
         )
 
+    def test_correlated_select(self):
+        table1, table2 = self.tables.mytable, self.tables.myothertable
+
         # test one that is actually correlated...
         s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
         self.assert_compile(
-            table1.delete(table1.c.name == s),
+            table1.delete(table1.c.name == s.scalar_subquery()),
             "DELETE FROM mytable "
             "WHERE mytable.name = ("
             "SELECT myothertable.othername "
index 7990cd56c69ab2301901bee4f5651fe26fa33233..8e8591aecc0d5c42efbbdc72220c14ddec9f2f6b 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import column
 from sqlalchemy import create_engine
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import select
@@ -17,6 +18,8 @@ from sqlalchemy import text
 from sqlalchemy import util
 from sqlalchemy.engine import default
 from sqlalchemy.schema import DDL
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -24,6 +27,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 
 
@@ -372,6 +376,117 @@ class ForUpdateTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(s._for_update_arg.nowait, True)
 
 
+class SubqueryCoercionsTest(fixtures.TestBase, AssertsCompiledSQL):
+    def test_column_roles(self):
+        stmt = select([table1.c.myid])
+
+        for role in [
+            roles.WhereHavingRole,
+            roles.ExpressionElementRole,
+            roles.ByOfRole,
+            roles.OrderByRole,
+            # roles.LabeledColumnExprRole
+        ]:
+            with testing.expect_deprecated(
+                "coercing SELECT object to scalar "
+                "subquery in a column-expression context is deprecated"
+            ):
+                coerced = coercions.expect(role, stmt)
+                is_true(coerced.compare(stmt.scalar_subquery()))
+
+            with testing.expect_deprecated(
+                "coercing SELECT object to scalar "
+                "subquery in a column-expression context is deprecated"
+            ):
+                coerced = coercions.expect(role, stmt.alias())
+                is_true(coerced.compare(stmt.scalar_subquery()))
+
+    def test_labeled_role(self):
+        stmt = select([table1.c.myid])
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            coerced = coercions.expect(roles.LabeledColumnExprRole, stmt)
+            is_true(coerced.compare(stmt.scalar_subquery().label(None)))
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            coerced = coercions.expect(
+                roles.LabeledColumnExprRole, stmt.alias()
+            )
+            is_true(coerced.compare(stmt.scalar_subquery().label(None)))
+
+    def test_scalar_select(self):
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            self.assert_compile(
+                func.coalesce(select([table1.c.myid])),
+                "coalesce((SELECT mytable.myid FROM mytable))",
+            )
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            s = select([table1.c.myid]).alias()
+            self.assert_compile(
+                select([table1.c.myid]).where(table1.c.myid == s),
+                "SELECT mytable.myid FROM mytable WHERE "
+                "mytable.myid = (SELECT mytable.myid FROM "
+                "mytable)",
+            )
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            self.assert_compile(
+                select([table1.c.myid]).where(s > table1.c.myid),
+                "SELECT mytable.myid FROM mytable WHERE "
+                "mytable.myid < (SELECT mytable.myid FROM "
+                "mytable)",
+            )
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            s = select([table1.c.myid]).alias()
+            self.assert_compile(
+                select([table1.c.myid]).where(table1.c.myid == s),
+                "SELECT mytable.myid FROM mytable WHERE "
+                "mytable.myid = (SELECT mytable.myid FROM "
+                "mytable)",
+            )
+
+        with testing.expect_deprecated(
+            "coercing SELECT object to scalar "
+            "subquery in a column-expression context is deprecated"
+        ):
+            self.assert_compile(
+                select([table1.c.myid]).where(s > table1.c.myid),
+                "SELECT mytable.myid FROM mytable WHERE "
+                "mytable.myid < (SELECT mytable.myid FROM "
+                "mytable)",
+            )
+
+    def test_as_scalar(self):
+        with testing.expect_deprecated(
+            r"The SelectBase.as_scalar\(\) method is deprecated and "
+            "will be removed in a future release."
+        ):
+            stmt = select([table1.c.myid]).as_scalar()
+
+        is_true(stmt.compare(select([table1.c.myid]).scalar_subquery()))
+
+
 class TextTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
@@ -425,3 +540,11 @@ class TextTest(fixtures.TestBase, AssertsCompiledSQL):
             "The text.autocommit parameter is deprecated"
         ):
             t = text("select id, name from user", autocommit=True)
+
+
+table1 = table(
+    "mytable",
+    column("myid", Integer),
+    column("name", String),
+    column("description", String),
+)
index 9b902e8ffd7c8801265c081b7076fe3b1e32f819..da139d7c04d17b2770d1d3b6f4f17fbbba7cf947 100644 (file)
@@ -22,7 +22,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors
-from sqlalchemy.sql.expression import _clone
+from sqlalchemy.sql.elements import _clone
 from sqlalchemy.sql.expression import _from_objects
 from sqlalchemy.sql.visitors import ClauseVisitor
 from sqlalchemy.sql.visitors import cloned_traverse
@@ -306,7 +306,7 @@ class BinaryEndpointTraversalTest(fixtures.TestBase):
 
     def test_subquery(self):
         a, b, c = column("a"), column("b"), column("c")
-        subq = select([c]).where(c == a).as_scalar()
+        subq = select([c]).where(c == a).scalar_subquery()
         expr = and_(a == b, b == subq)
         self._assert_traversal(
             expr, [(operators.eq, a, b), (operators.eq, b, subq)]
@@ -706,7 +706,9 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
                 select.append_whereclause(t1.c.col2 == 7)
 
         self.assert_compile(
-            select([t2]).where(t2.c.col1 == Vis().traverse(s)),
+            select([t2]).where(
+                t2.c.col1 == Vis().traverse(s).scalar_subquery()
+            ),
             "SELECT table2.col1, table2.col2, table2.col3 "
             "FROM table2 WHERE table2.col1 = "
             "(SELECT * FROM table1 WHERE table1.col1 = table2.col1 "
@@ -739,7 +741,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         t1a = t1.alias()
 
         s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1a)
-        s = select([t1]).where(t1.c.col1 == s)
+        s = select([t1]).where(t1.c.col1 == s.scalar_subquery())
         self.assert_compile(
             s,
             "SELECT table1.col1, table1.col2, table1.col3 FROM table1 "
@@ -760,7 +762,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         s = select([t1]).where(t1.c.col1 == "foo").alias()
 
         s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1)
-        s3 = select([t1]).where(t1.c.col1 == s2)
+        s3 = select([t1]).where(t1.c.col1 == s2.scalar_subquery())
         self.assert_compile(
             s3,
             "SELECT table1.col1, table1.col2, table1.col3 "
@@ -982,7 +984,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
 
         s = select(
             [literal_column("*")], from_obj=[t1alias, t2alias]
-        ).as_scalar()
+        ).scalar_subquery()
         assert t2alias in s._froms
         assert t1alias in s._froms
 
@@ -1011,7 +1013,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         s = (
             select([literal_column("*")], from_obj=[t1alias, t2alias])
             .correlate(t2alias)
-            .as_scalar()
+            .scalar_subquery()
         )
         self.assert_compile(
             select([literal_column("*")], t2alias.c.col1 == s),
@@ -1037,7 +1039,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         s = (
             select([literal_column("*")])
             .where(t1.c.col1 == t2.c.col1)
-            .as_scalar()
+            .scalar_subquery()
         )
         self.assert_compile(
             select([t1.c.col1, s]),
@@ -1064,7 +1066,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             select([literal_column("*")])
             .where(t1.c.col1 == t2.c.col1)
             .correlate(t1)
-            .as_scalar()
+            .scalar_subquery()
         )
         self.assert_compile(
             select([t1.c.col1, s]),
@@ -1102,7 +1104,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             select([t2.c.col1])
             .where(t2.c.col1 == t1.c.col1)
             .correlate(t2)
-            .as_scalar()
+            .scalar_subquery()
         )
 
         # test subquery - given only t1 and t2 in the enclosing selectable,
@@ -1112,7 +1114,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             select([t2.c.col1])
             .where(t2.c.col1 == t1.c.col1)
             .correlate_except(t1)
-            .as_scalar()
+            .scalar_subquery()
         )
 
         # use both subqueries in statements
@@ -1257,7 +1259,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
                         [literal_column("*")],
                         t1.c.col1 == t2.c.col2,
                         from_obj=[t1, t2],
-                    ).correlate(t1)
+                    )
+                    .correlate(t1)
+                    .scalar_subquery()
                 )
             ),
             "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
@@ -1277,7 +1281,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
                         [literal_column("*")],
                         t1.c.col1 == t2.c.col2,
                         from_obj=[t1, t2],
-                    ).correlate(t2)
+                    )
+                    .correlate(t2)
+                    .scalar_subquery()
                 )
             ),
             "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
@@ -1381,9 +1387,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             select([t1alias, t2alias]).where(
                 t1alias.c.col1
                 == vis.traverse(
-                    select(
-                        ["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2]
-                    ).correlate(t1)
+                    select(["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2])
+                    .correlate(t1)
+                    .scalar_subquery()
                 )
             ),
             "SELECT t1alias.col1, t1alias.col2, t1alias.col3, "
@@ -1403,9 +1409,9 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             t2alias.select().where(
                 t2alias.c.col2
                 == vis.traverse(
-                    select(
-                        ["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2]
-                    ).correlate(t2)
+                    select(["*"], t1.c.col1 == t2.c.col2, from_obj=[t1, t2])
+                    .correlate(t2)
+                    .scalar_subquery()
                 )
             ),
             "SELECT t2alias.col1, t2alias.col2, t2alias.col3 "
index 0b7aa7a555289599828c62bb087c36ce4fee0aa9..d2a2c1c4842a942628bb5dd6c12f7e51b4c60545 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import Table
+from sqlalchemy.sql import ClauseElement
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 
@@ -37,5 +38,13 @@ class TestCoreInspection(fixtures.TestBase):
         # absence of __clause_element__ as a test for "this is the clause
         # element" must be maintained
 
+        class Foo(ClauseElement):
+            pass
+
+        assert not hasattr(Foo(), "__clause_element__")
+
+    def test_col_now_has_a_clauseelement(self):
+
         x = Column("foo", Integer)
-        assert not hasattr(x, "__clause_element__")
+
+        assert hasattr(x, "__clause_element__")
index e91557cd159dce84c587ad666a74a2fd840eda82..b9bcfc16a1069e81bc37e53f7d0e0f60097592a8 100644 (file)
@@ -259,8 +259,8 @@ class _JoinRewriteTestBase(AssertsCompiledSQL):
         self._test(s, self._f_b1a_where_in_b2a)
 
     def test_anon_scalar_subqueries(self):
-        s1 = select([1]).as_scalar()
-        s2 = select([2]).as_scalar()
+        s1 = select([1]).scalar_subquery()
+        s2 = select([2]).scalar_subquery()
 
         s = select([s1, s2]).apply_labels()
         self._test(s, self._anon_scalar_subqueries)
index 3d60fb60ed2b76805cb3ac229d473e65f69ea596..3f9667609655d14903f4bc00f64b4540e05a5010 100644 (file)
@@ -31,7 +31,6 @@ from sqlalchemy import Unicode
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import util
 from sqlalchemy.engine import default
-from sqlalchemy.sql import elements
 from sqlalchemy.sql import naming
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -3378,7 +3377,8 @@ class ConstraintTest(fixtures.TestBase):
 
         assert_raises_message(
             exc.ArgumentError,
-            r"Element Table\('t2', .* is not a string name or column element",
+            r"String column name or column object for DDL constraint "
+            r"expected, got .*SomeClass",
             Index,
             "foo",
             SomeClass(),
@@ -4609,7 +4609,7 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
         u1 = self._fixture(
             naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}
         )
-        ck = CheckConstraint(u1.c.data == "x", name=elements._defer_name(None))
+        ck = CheckConstraint(u1.c.data == "x", name=naming._defer_name(None))
 
         assert_raises_message(
             exc.InvalidRequestError,
index c6eff6ac93c92d4e484ac13c5d6fb53324ff6c1b..f85a601bab742dd1edb0c248245c3c21e3f07faa 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy.schema import Table
 from sqlalchemy.sql import all_
 from sqlalchemy.sql import any_
 from sqlalchemy.sql import asc
+from sqlalchemy.sql import coercions
 from sqlalchemy.sql import collate
 from sqlalchemy.sql import column
 from sqlalchemy.sql import compiler
@@ -34,10 +35,10 @@ from sqlalchemy.sql import false
 from sqlalchemy.sql import literal
 from sqlalchemy.sql import null
 from sqlalchemy.sql import operators
+from sqlalchemy.sql import roles
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import true
-from sqlalchemy.sql.elements import _literal_as_text
 from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import Label
 from sqlalchemy.sql.expression import BinaryExpression
@@ -82,13 +83,17 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
 
         assert left.comparator.operate(operator, right).compare(
             BinaryExpression(
-                _literal_as_text(left), _literal_as_text(right), operator
+                coercions.expect(roles.WhereHavingRole, left),
+                coercions.expect(roles.WhereHavingRole, right),
+                operator,
             )
         )
 
         assert operator(left, right).compare(
             BinaryExpression(
-                _literal_as_text(left), _literal_as_text(right), operator
+                coercions.expect(roles.WhereHavingRole, left),
+                coercions.expect(roles.WhereHavingRole, right),
+                operator,
             )
         )
 
@@ -227,8 +232,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         left = column("left")
         foo = ClauseList()
         assert_raises_message(
-            exc.InvalidRequestError,
-            r"in_\(\) accepts either a list of expressions, a selectable",
+            exc.ArgumentError,
+            r"IN expression list, SELECT construct, or bound parameter "
+            r"object expected, got .*ClauseList",
             left.in_,
             [foo],
         )
@@ -237,8 +243,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         left = column("left")
         right = column("right")
         assert_raises_message(
-            exc.InvalidRequestError,
-            r"in_\(\) accepts either a list of expressions, a selectable",
+            exc.ArgumentError,
+            r"IN expression list, SELECT construct, or bound parameter "
+            r"object expected, got .*ColumnClause",
             left.in_,
             right,
         )
@@ -253,8 +260,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         left = column("left")
         right = column("right", HasGetitem)
         assert_raises_message(
-            exc.InvalidRequestError,
-            r"in_\(\) accepts either a list of expressions, a selectable",
+            exc.ArgumentError,
+            r"IN expression list, SELECT construct, or bound parameter "
+            r"object expected, got .*ColumnClause",
             left.in_,
             right,
         )
@@ -1680,7 +1688,7 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             select(
                 [
                     self.table1.c.myid.in_(
-                        select([self.table2.c.otherid]).as_scalar()
+                        select([self.table2.c.otherid]).scalar_subquery()
                     )
                 ]
             ),
@@ -1738,6 +1746,29 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             self.table1.c.myid.in_([None]), "mytable.myid IN (NULL)"
         )
 
+    def test_in_set(self):
+        self.assert_compile(
+            self.table1.c.myid.in_({1, 2, 3}),
+            "mytable.myid IN (:myid_1, :myid_2, :myid_3)",
+        )
+
+    def test_in_arbitrary_sequence(self):
+        class MySeq(object):
+            def __init__(self, d):
+                self.d = d
+
+            def __getitem__(self, idx):
+                return self.d[idx]
+
+            def __iter__(self):
+                return iter(self.d)
+
+        seq = MySeq([1, 2, 3])
+        self.assert_compile(
+            self.table1.c.myid.in_(seq),
+            "mytable.myid IN (:myid_1, :myid_2, :myid_3)",
+        )
+
     def test_empty_in_dynamic_1(self):
         self.assert_compile(
             self.table1.c.myid.in_([]),
@@ -2073,7 +2104,9 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         assert not (self.table1.c.myid + 5)._is_implicitly_boolean
         assert not not_(column("x", Boolean))._is_implicitly_boolean
         assert (
-            not select([self.table1.c.myid]).as_scalar()._is_implicitly_boolean
+            not select([self.table1.c.myid])
+            .scalar_subquery()
+            ._is_implicitly_boolean
         )
         assert not text("x = y")._is_implicitly_boolean
         assert not literal_column("x = y")._is_implicitly_boolean
@@ -2869,7 +2902,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         t = self._fixture()
 
         self.assert_compile(
-            5 == any_(select([t.c.data]).where(t.c.data < 10)),
+            5
+            == any_(select([t.c.data]).where(t.c.data < 10).scalar_subquery()),
             ":param_1 = ANY (SELECT tab1.data "
             "FROM tab1 WHERE tab1.data < :data_1)",
             checkparams={"data_1": 10, "param_1": 5},
@@ -2879,7 +2913,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         t = self._fixture()
 
         self.assert_compile(
-            5 == select([t.c.data]).where(t.c.data < 10).as_scalar().any_(),
+            5
+            == select([t.c.data])
+            .where(t.c.data < 10)
+            .scalar_subquery()
+            .any_(),
             ":param_1 = ANY (SELECT tab1.data "
             "FROM tab1 WHERE tab1.data < :data_1)",
             checkparams={"data_1": 10, "param_1": 5},
@@ -2889,7 +2927,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         t = self._fixture()
 
         self.assert_compile(
-            5 == all_(select([t.c.data]).where(t.c.data < 10)),
+            5
+            == all_(select([t.c.data]).where(t.c.data < 10).scalar_subquery()),
             ":param_1 = ALL (SELECT tab1.data "
             "FROM tab1 WHERE tab1.data < :data_1)",
             checkparams={"data_1": 10, "param_1": 5},
@@ -2899,7 +2938,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         t = self._fixture()
 
         self.assert_compile(
-            5 == select([t.c.data]).where(t.c.data < 10).as_scalar().all_(),
+            5
+            == select([t.c.data])
+            .where(t.c.data < 10)
+            .scalar_subquery()
+            .all_(),
             ":param_1 = ALL (SELECT tab1.data "
             "FROM tab1 WHERE tab1.data < :data_1)",
             checkparams={"data_1": 10, "param_1": 5},
index 5987c77467f33630928e1e9443db9321d640e16c..6e48374ca3c32d3f13c52dfdf0b3581c2c73f99a 100644 (file)
@@ -121,7 +121,7 @@ class ResultProxyTest(fixtures.TablesTest):
         sel = (
             select([users.c.user_id])
             .where(users.c.user_name == "jack")
-            .as_scalar()
+            .scalar_subquery()
         )
         for row in select([sel + 1, sel + 3], bind=users.bind).execute():
             eq_(row["anon_1"], 8)
index aa10626c5c853aea3d634f3d5d39ef9666f31fe5..ad1eb33480806452a394e83e4c28556e797181d3 100644 (file)
@@ -222,7 +222,7 @@ class CompositeStatementTest(fixtures.TestBase):
 
             stmt = (
                 t2.insert()
-                .values(x=select([t1.c.x]).as_scalar())
+                .values(x=select([t1.c.x]).scalar_subquery())
                 .returning(t2.c.x)
             )
 
diff --git a/test/sql/test_roles.py b/test/sql/test_roles.py
new file mode 100644 (file)
index 0000000..81934de
--- /dev/null
@@ -0,0 +1,218 @@
+from sqlalchemy import Column
+from sqlalchemy import exc
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy import text
+from sqlalchemy.schema import DDL
+from sqlalchemy.schema import Sequence
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import column
+from sqlalchemy.sql import false
+from sqlalchemy.sql import False_
+from sqlalchemy.sql import literal
+from sqlalchemy.sql import roles
+from sqlalchemy.sql import true
+from sqlalchemy.sql import True_
+from sqlalchemy.sql.coercions import expect
+from sqlalchemy.sql.elements import _truncated_label
+from sqlalchemy.sql.elements import Null
+from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_instance_of
+from sqlalchemy.testing import is_true
+
+m = MetaData()
+
+t = Table("t", m, Column("q", Integer))
+
+
+class NotAThing1(object):
+    pass
+
+
+not_a_thing1 = NotAThing1()
+
+
+class NotAThing2(ClauseElement):
+    pass
+
+
+not_a_thing2 = NotAThing2()
+
+
+class NotAThing3(object):
+    def __clause_element__(self):
+        return not_a_thing2
+
+
+not_a_thing3 = NotAThing3()
+
+
+class RoleTest(fixtures.TestBase):
+    # TODO: the individual role tests here are incomplete.  The functionality
+    # of each role is covered by other tests in the sql testing suite however
+    # ideally they'd all have direct tests here as well.
+
+    def _test_role_neg_comparisons(self, role):
+        impl = coercions._impl_lookup[role]
+        role_name = impl.name
+
+        assert_raises_message(
+            exc.ArgumentError,
+            r"%s expected, got .*NotAThing1" % role_name,
+            expect,
+            role,
+            not_a_thing1,
+        )
+
+        assert_raises_message(
+            exc.ArgumentError,
+            r"%s expected, got .*NotAThing2" % role_name,
+            expect,
+            role,
+            not_a_thing2,
+        )
+
+        assert_raises_message(
+            exc.ArgumentError,
+            r"%s expected, got .*NotAThing3" % role_name,
+            expect,
+            role,
+            not_a_thing3,
+        )
+
+        assert_raises_message(
+            exc.ArgumentError,
+            r"%s expected for argument 'foo'; got .*NotAThing3" % role_name,
+            expect,
+            role,
+            not_a_thing3,
+            argname="foo",
+        )
+
+    def test_const_expr_role(self):
+        t = true()
+        is_(expect(roles.ConstExprRole, t), t)
+
+        f = false()
+        is_(expect(roles.ConstExprRole, f), f)
+
+        is_instance_of(expect(roles.ConstExprRole, True), True_)
+
+        is_instance_of(expect(roles.ConstExprRole, False), False_)
+
+        is_instance_of(expect(roles.ConstExprRole, None), Null)
+
+    def test_truncated_label_role(self):
+        is_instance_of(
+            expect(roles.TruncatedLabelRole, "foobar"), _truncated_label
+        )
+
+    def test_labeled_column_expr_role(self):
+        c = column("q")
+        is_true(expect(roles.LabeledColumnExprRole, c).compare(c))
+
+        is_true(
+            expect(roles.LabeledColumnExprRole, c.label("foo")).compare(
+                c.label("foo")
+            )
+        )
+
+        is_true(
+            expect(
+                roles.LabeledColumnExprRole,
+                select([column("q")]).scalar_subquery(),
+            ).compare(select([column("q")]).label(None))
+        )
+
+        is_true(
+            expect(roles.LabeledColumnExprRole, not_a_thing1).compare(
+                literal(not_a_thing1).label(None)
+            )
+        )
+
+    def test_scalar_select_no_coercion(self):
+        # this is also tested in test/sql/test_deprecations.py; when the
+        # deprecation is turned to an error, those tests go away, and these
+        # will assert the correct exception plus informative error message.
+        assert_raises_message(
+            exc.SADeprecationWarning,
+            "coercing SELECT object to scalar subquery in a column-expression "
+            "context is deprecated",
+            expect,
+            roles.LabeledColumnExprRole,
+            select([column("q")]),
+        )
+
+        assert_raises_message(
+            exc.SADeprecationWarning,
+            "coercing SELECT object to scalar subquery in a column-expression "
+            "context is deprecated",
+            expect,
+            roles.LabeledColumnExprRole,
+            select([column("q")]).alias(),
+        )
+
+    def test_statement_no_text_coercion(self):
+        assert_raises_message(
+            exc.ArgumentError,
+            r"Textual SQL expression 'select \* from table' should be "
+            r"explicitly declared",
+            expect,
+            roles.StatementRole,
+            "select * from table",
+        )
+
+    def test_statement_text_coercion(self):
+        is_true(
+            expect(
+                roles.CoerceTextStatementRole, "select * from table"
+            ).compare(text("select * from table"))
+        )
+
+    def test_select_statement_no_text_coercion(self):
+        assert_raises_message(
+            exc.ArgumentError,
+            r"Textual SQL expression 'select \* from table' should be "
+            r"explicitly declared",
+            expect,
+            roles.SelectStatementRole,
+            "select * from table",
+        )
+
+    def test_statement_coercion_select(self):
+        is_true(
+            expect(roles.CoerceTextStatementRole, select([t])).compare(
+                select([t])
+            )
+        )
+
+    def test_statement_coercion_ddl(self):
+        d1 = DDL("hi")
+        is_(expect(roles.CoerceTextStatementRole, d1), d1)
+
+    def test_statement_coercion_sequence(self):
+        s1 = Sequence("hi")
+        is_(expect(roles.CoerceTextStatementRole, s1), s1)
+
+    def test_columns_clause_role(self):
+        is_(expect(roles.ColumnsClauseRole, t.c.q), t.c.q)
+
+    def test_truncated_label_role_neg(self):
+        self._test_role_neg_comparisons(roles.TruncatedLabelRole)
+
+    def test_where_having_role_neg(self):
+        self._test_role_neg_comparisons(roles.WhereHavingRole)
+
+    def test_by_of_role_neg(self):
+        self._test_role_neg_comparisons(roles.ByOfRole)
+
+    def test_const_expr_role_neg(self):
+        self._test_role_neg_comparisons(roles.ConstExprRole)
+
+    def test_columns_clause_role_neg(self):
+        self._test_role_neg_comparisons(roles.ColumnsClauseRole)
index a4d3e1b406a6b6197493fbd6f9619351b3ecd176..f88243fc27e5f02b33446a8b412096c610620393 100644 (file)
@@ -31,7 +31,6 @@ from sqlalchemy import util
 from sqlalchemy.sql import Alias
 from sqlalchemy.sql import column
 from sqlalchemy.sql import elements
-from sqlalchemy.sql import expression
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors
@@ -221,7 +220,7 @@ class SelectableTest(
         s3c1 = s3._clone()
 
         eq_(
-            expression._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]),
+            elements._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]),
             set([s1c1]),
         )
 
@@ -240,7 +239,7 @@ class SelectableTest(
         s3c1 = s3._clone()
 
         eq_(
-            expression._cloned_difference([s1c1, s2c1, s3c1], [s2c1, s1c2]),
+            elements._cloned_difference([s1c1, s2c1, s3c1], [s2c1, s1c2]),
             set([s3c1]),
         )
 
@@ -313,7 +312,7 @@ class SelectableTest(
         assert sel3.corresponding_column(col) is sel3.c.foo
 
     def test_with_only_generative(self):
-        s1 = table1.select().as_scalar()
+        s1 = table1.select().scalar_subquery()
         self.assert_compile(
             s1.with_only_columns([s1]),
             "SELECT (SELECT table1.col1, table1.col2, "
@@ -365,12 +364,18 @@ class SelectableTest(
         criterion = a.c.col1 == table2.c.col2
         self.assert_(criterion.compare(j.onclause))
 
+    @testing.fails("not supported with rework, need a new approach")
     def test_alias_handles_column_context(self):
         # not quite a use case yet but this is expected to become
         # prominent w/ PostgreSQL's tuple functions
 
         stmt = select([table1.c.col1, table1.c.col2])
         a = stmt.alias("a")
+
+        # TODO: this case is crazy, sending SELECT or FROMCLAUSE has to
+        # be figured out - is it a scalar row query?  what kinds of
+        # statements go into functions in PG. seems likely select statment,
+        # but not alias, subquery or other FROM object
         self.assert_compile(
             select([func.foo(a)]),
             "SELECT foo(SELECT table1.col1, table1.col2 FROM table1) "
@@ -652,7 +657,7 @@ class SelectableTest(
         self.assert_(criterion.compare(j.onclause))
 
     def test_scalar_cloned_comparator(self):
-        sel = select([table1.c.col1]).as_scalar()
+        sel = select([table1.c.col1]).scalar_subquery()
         expr = sel == table1.c.col1
 
         sel2 = visitors.ReplacingCloningVisitor().traverse(sel)
@@ -2535,7 +2540,7 @@ class ResultMapTest(fixtures.TestBase):
 
     def test_column_subquery_plain(self):
         t = self._fixture()
-        s1 = select([t.c.x]).where(t.c.x > 5).as_scalar()
+        s1 = select([t.c.x]).where(t.c.x > 5).scalar_subquery()
         s2 = select([s1])
         mapping = self._mapping(s2)
         assert t.c.x not in mapping
index 188ac3878caaa9d2acf78d67e17a11fcb53d860c..afc7755983bb72a2b4e745eb843db3a4c4922845 100644 (file)
@@ -564,7 +564,7 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def test_scalar_subquery(self):
         t = text("select id from user").columns(id=Integer)
-        subq = t.as_scalar()
+        subq = t.scalar_subquery()
 
         assert subq.type._type_affinity is Integer()._type_affinity
 
index a1b1f024b91d7a66e72dd56b8be70c4617aabfa6..a5c9313f80d2bda2b9beabfae8b424c642d93ef0 100644 (file)
@@ -2608,7 +2608,8 @@ class ExpressionTest(
 
         assert_raises_message(
             exc.ArgumentError,
-            r"Object some_sqla_thing\(\) is not legal as a SQL literal value",
+            r"SQL expression element or literal value expected, got "
+            r"some_sqla_thing\(\).",
             lambda: column("a", String) == SomeSQLAThing(),
         )
 
index 514076daab3720eb8a8473c52f47c6bea0733669..9309ca45a9f6108f8856452ef03af5bd1f7dfbe2 100644 (file)
@@ -179,7 +179,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
 
         # test a non-correlated WHERE clause
         s = select([table2.c.othername], table2.c.otherid == 7)
-        u = update(table1, table1.c.name == s)
+        u = update(table1, table1.c.name == s.scalar_subquery())
         self.assert_compile(
             u,
             "UPDATE mytable SET myid=:myid, name=:name, "
@@ -194,7 +194,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
 
         # test one that is actually correlated...
         s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
-        u = table1.update(table1.c.name == s)
+        u = table1.update(table1.c.name == s.scalar_subquery())
         self.assert_compile(
             u,
             "UPDATE mytable SET myid=:myid, name=:name, "