]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added new checks for the common error case of passing mapped classes
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Aug 2015 16:47:13 +0000 (12:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Aug 2015 16:47:13 +0000 (12:47 -0400)
or mapped instances into contexts where they are interpreted as
SQL bound parameters; a new exception is raised for this.
fixes #3321

doc/build/changelog/changelog_11.rst
doc/build/changelog/migration_11.rst
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/types.py
test/aaa_profiling/test_compiler.py
test/orm/test_query.py
test/sql/test_types.py

index 0f974dc8c577b0864a4ac2a1f28e0f4bd39eb931..695aa3c5c62b839400e9a07d4d816d313928ef8d 100644 (file)
 .. changelog::
     :version: 1.1.0b1
 
+    .. change::
+        :tags: feature, orm
+        :tickets: 3321
+
+        Added new checks for the common error case of passing mapped classes
+        or mapped instances into contexts where they are interpreted as
+        SQL bound parameters; a new exception is raised for this.
+
+        .. seealso::
+
+            :ref:`change_3321`
+
     .. change::
         :tags: bug, postgresql
         :tickets: 3499
index c40d5a9c1e6e6e9114c3b2f578ed329e5687b2be..849d4516b693a365aac1efffb4eda0ec99add022 100644 (file)
@@ -104,6 +104,47 @@ approach which applied a counter to the object.
 
 :ticket:`3499`
 
+.. _change_3321:
+
+Specific checks added for passing mapped classes, instances as SQL literals
+---------------------------------------------------------------------------
+
+The typing system now has specific checks for passing of SQLAlchemy
+"inspectable" objects in contexts where they would otherwise be handled as
+literal values.   Any SQLAlchemy built-in object that is legal to pass as a
+SQL value includes a method ``__clause_element__()`` which provides a
+valid SQL expression for that object.  For SQLAlchemy objects that
+don't provide this, such as mapped classes, mappers, and mapped
+instances, a more informative error message is emitted rather than
+allowing the DBAPI to receive the object and fail later.  An example
+is illustrated below, where a string-based attribute ``User.name`` is
+compared to a full instance of ``User()``, rather than against a
+string value::
+
+    >>> some_user = User()
+    >>> q = s.query(User).filter(User.name == some_user)
+    ...
+    sqlalchemy.exc.ArgumentError: Object <__main__.User object at 0x103167e90> is not legal as a SQL literal value
+
+The exception is now immediate when the comparison is made between
+``User.name == some_user``.  Previously, a comparison like the above
+would produce a SQL expression that would only fail once resolved
+into a DBAPI execution call; the mapped ``User`` object would
+ultimately become a bound parameter that would be rejected by the
+DBAPI.
+
+Note that in the above example, the expression fails because
+``User.name`` is a string-based (e.g. column oriented) attribute.
+The change does *not* impact the usual case of comparing a many-to-one
+relationship attribute to an object, which is handled distinctly::
+
+    >>> # Address.user refers to the User mapper, so
+    >>> # this is of course still OK!
+    >>> q = s.query(Address).filter(Address.user == some_user)
+
+
+:ticket:`3321`
+
 New Features and Improvements - Core
 ====================================
 
index d3c46e6437db706d4777d59adb5cb44900ac7fac..4717b777fc5da0a03f82edd89ada91550d46de4e 100644 (file)
@@ -281,6 +281,8 @@ class _CompileLabel(visitors.Visitable):
     def type(self):
         return self.element.type
 
+    def self_group(self, **kw):
+        return self
 
 class SQLCompiler(Compiled):
 
index 09f6391638d5f901ae7b6fa22d5933f9ff2a139d..125fec33f85ccc069cfe8c2017ac10f5670adda6 100644 (file)
@@ -15,7 +15,7 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \
     Null, _const_expr, _clause_element_as_expr, \
     ClauseList, ColumnElement, TextClause, UnaryExpression, \
     collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
-    Slice
+    Slice, Visitable
 from .selectable import SelectBase, Alias, Selectable, ScalarSelect
 
 
@@ -304,7 +304,7 @@ def _check_literal(expr, operator, other):
 
     if isinstance(other, (SelectBase, Alias)):
         return other.as_scalar()
-    elif not isinstance(other, (ColumnElement, TextClause)):
+    elif not isinstance(other, Visitable):
         return expr._bind_param(operator, other)
     else:
         return other
index 00c749b4080660e38574a848e735898ce94cf040..e2d81afc184441caa305f8d509434d1635f906d4 100644 (file)
@@ -1145,8 +1145,7 @@ class BindParameter(ColumnElement):
                     _compared_to_type.coerce_compared_value(
                         _compared_to_operator, value)
             else:
-                self.type = type_api._type_map.get(type(value),
-                                                   type_api.NULLTYPE)
+                self.type = type_api._resolve_value_to_type(value)
         elif isinstance(type_, type):
             self.type = type_()
         else:
@@ -1161,8 +1160,7 @@ class BindParameter(ColumnElement):
         cloned.callable = None
         cloned.required = False
         if cloned.type is type_api.NULLTYPE:
-            cloned.type = type_api._type_map.get(type(value),
-                                                 type_api.NULLTYPE)
+            cloned.type = type_api._resolve_value_to_type(value)
         return cloned
 
     @property
index ec7dea3007dd6eb88c95f1ea1d8fef75d9348fc9..b5c575143bff8ae70a3fa05d04fd595b01e8cc52 100644 (file)
@@ -9,7 +9,6 @@
 
 """
 
-import collections
 import datetime as dt
 import codecs
 
@@ -18,6 +17,7 @@ from .elements import quoted_name, type_coerce, _defer_name
 from .. import exc, util, processors
 from .base import _bind_or_error, SchemaEventTarget
 from . import operators
+from .. import inspection
 from .. import event
 from ..util import pickle
 import decimal
@@ -1736,6 +1736,21 @@ else:
     _type_map[unicode] = Unicode()
     _type_map[str] = String()
 
+_type_map_get = _type_map.get
+
+
+def _resolve_value_to_type(value):
+    _result_type = _type_map_get(type(value), False)
+    if _result_type is False:
+        # use inspect() to detect SQLAlchemy built-in
+        # objects.
+        insp = inspection.inspect(value, False)
+        if insp is not None:
+            raise exc.ArgumentError(
+                "Object %r is not legal as a SQL literal value" % value)
+        return NULLTYPE
+    else:
+        return _result_type
 
 # back-assign to type_api
 from . import type_api
@@ -1745,6 +1760,5 @@ type_api.INTEGERTYPE = INTEGERTYPE
 type_api.NULLTYPE = NULLTYPE
 type_api.MATCHTYPE = MATCHTYPE
 type_api.INDEXABLE = Indexable
-type_api._type_map = _type_map
-
+type_api._resolve_value_to_type = _resolve_value_to_type
 TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE
index c4e830b7f3e8d49db614b1881b3a3a31634d0639..b9826e585c6c4111dbdbca7fa0ab59c7c37fbcad 100644 (file)
@@ -21,6 +21,7 @@ NULLTYPE = None
 STRINGTYPE = None
 MATCHTYPE = None
 INDEXABLE = None
+_resolve_value_to_type = None
 
 
 class TypeEngine(Visitable):
@@ -454,7 +455,7 @@ class TypeEngine(Visitable):
         end-user customization of this behavior.
 
         """
-        _coerced_type = _type_map.get(type(value), NULLTYPE)
+        _coerced_type = _resolve_value_to_type(value)
         if _coerced_type is NULLTYPE or _coerced_type._type_affinity \
                 is self._type_affinity:
             return self
index 3a0e2a58fa73943c9954e654b2b124885a1f0c88..61b89969f9001a266d3da2046f751da389b59995 100644 (file)
@@ -76,5 +76,4 @@ from .sql.sqltypes import (
     UnicodeText,
     VARBINARY,
     VARCHAR,
-    _type_map
     )
index 5eece46025d5a966f90eb1b91eaca2f22f7c0dd3..5095be1032dc4126d4cc6330a8086f1be4dd94d9 100644 (file)
@@ -32,8 +32,8 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
         for t in (t1, t2):
             for c in t.c:
                 c.type._type_affinity
-        from sqlalchemy import types
-        for t in list(types._type_map.values()):
+        from sqlalchemy.sql import sqltypes
+        for t in list(sqltypes._type_map.values()):
             t._type_affinity
 
         cls.dialect = default.DefaultDialect()
index 3ed2e7d7a97dbd4880744212ad06428d3c1a2e5c..b0501739f6942a8b8591a1fb8d3b04263c2a725c 100644 (file)
@@ -776,6 +776,42 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
                 meth, q, *arg, **kw
             )
 
+    def test_illegal_coercions(self):
+        User = self.classes.User
+
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Object .*User.* is not legal as a SQL literal value",
+            distinct, User
+        )
+
+        ua = aliased(User)
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Object .*User.* is not legal as a SQL literal value",
+            distinct, ua
+        )
+
+        s = Session()
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Object .*User.* is not legal as a SQL literal value",
+            lambda: s.query(User).filter(User.name == User)
+        )
+
+        u1 = User()
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Object .*User.* is not legal as a SQL literal value",
+            distinct, u1
+        )
+
+        assert_raises_message(
+            sa_exc.ArgumentError,
+            "Object .*User.* is not legal as a SQL literal value",
+            lambda: s.query(User).filter(User.name == u1)
+        )
+
 
 class OperatorTest(QueryTest, AssertsCompiledSQL):
     """test sql.Comparator implementation for MapperProperties"""
index 0ab8ef451cbc99becc6c72935196890d1cd6e581..90fac97c21b449a23ba7e4c814dd4c3c61eed323 100644 (file)
@@ -1,5 +1,6 @@
 # coding: utf-8
-from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, expect_warnings
+from sqlalchemy.testing import eq_, is_, assert_raises, \
+    assert_raises_message, expect_warnings
 import decimal
 import datetime
 import os
@@ -11,7 +12,7 @@ from sqlalchemy import (
     BLOB, NCHAR, NVARCHAR, CLOB, TIME, DATE, DATETIME, TIMESTAMP, SMALLINT,
     INTEGER, DECIMAL, NUMERIC, FLOAT, REAL)
 from sqlalchemy.sql import ddl
-
+from sqlalchemy import inspection
 from sqlalchemy import exc, types, util, dialects
 for name in dialects.__all__:
     __import__("sqlalchemy.dialects.%s" % name)
@@ -1647,6 +1648,26 @@ class ExpressionTest(
         assert distinct(test_table.c.data).type == test_table.c.data.type
         assert test_table.c.data.distinct().type == test_table.c.data.type
 
+    def test_detect_coercion_of_builtins(self):
+        @inspection._self_inspects
+        class SomeSQLAThing(object):
+            def __repr__(self):
+                return "some_sqla_thing()"
+
+        class SomeOtherThing(object):
+            pass
+
+        assert_raises_message(
+            exc.ArgumentError,
+            r"Object some_sqla_thing\(\) is not legal as a SQL literal value",
+            lambda: column('a', String) == SomeSQLAThing()
+        )
+
+        is_(
+            bindparam('x', SomeOtherThing()).type,
+            types.NULLTYPE
+        )
+
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'