]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve type detection for Values / Tuple
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Dec 2020 17:22:12 +0000 (12:22 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Dec 2020 02:50:14 +0000 (21:50 -0500)
Fixed issue in new :class:`_sql.Values` construct where passing tuples of
objects would fall back to per-value type detection rather than making use
of the :class:`_schema.Column` objects passed directly to
:class:`_sql.Values` that tells SQLAlchemy what the expected type is. This
would lead to issues for objects such as enumerations and numpy strings
that are not actually necessary since the expected type is given.

note this changes NullType() to raise CompileError for
literal_processor; NullType() does not imply the actual value
NULL as much as it does "unknown type" so this should make failure
modes more clear.

Fixes: #5785
Change-Id: Ifbf5e78373102380b301098f30e15011efa98b5e

doc/build/changelog/unreleased_14/5785.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_values.py

diff --git a/doc/build/changelog/unreleased_14/5785.rst b/doc/build/changelog/unreleased_14/5785.rst
new file mode 100644 (file)
index 0000000..2e07d2d
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5785
+
+    Fixed issue in new :class:`_sql.Values` construct where passing tuples of
+    objects would fall back to per-value type detection rather than making use
+    of the :class:`_schema.Column` objects passed directly to
+    :class:`_sql.Values` that tells SQLAlchemy what the expected type is. This
+    would lead to issues for objects such as enumerations and numpy strings
+    that are not actually necessary since the expected type is given.
\ No newline at end of file
index 46f111d2c5770cf25be07ddc6d017545f7a53042..a734bb5825db3e0dfbe7d49f35b21c802a434161 100644 (file)
@@ -2610,7 +2610,9 @@ class SQLCompiler(Compiled):
 
         v = "VALUES %s" % ", ".join(
             self.process(
-                elements.Tuple(*elem).self_group(),
+                elements.Tuple(
+                    types=element._column_types, *elem
+                ).self_group(),
                 literal_binds=element.literal_binds,
             )
             for chunk in element._data
index ab8701dd6566689540218df9daa3536ef7397041..75c1fc1bf35755a8d754c10f0394b59f306f7ab2 100644 (file)
@@ -2497,11 +2497,28 @@ class Tuple(ClauseList, ColumnElement):
         """
         sqltypes = util.preloaded.sql_sqltypes
 
-        clauses = [
-            coercions.expect(roles.ExpressionElementRole, c) for c in clauses
-        ]
-        self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
+        types = kw.pop("types", None)
+        if types is None:
+            clauses = [
+                coercions.expect(roles.ExpressionElementRole, c)
+                for c in clauses
+            ]
+        else:
+            if len(types) != len(clauses):
+                raise exc.ArgumentError(
+                    "Wrong number of elements for %d-tuple: %r "
+                    % (len(types), clauses)
+                )
+            clauses = [
+                coercions.expect(
+                    roles.ExpressionElementRole,
+                    c,
+                    type_=typ if not typ._isnull else None,
+                )
+                for typ, c in zip(types, clauses)
+            ]
 
+        self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
         super(Tuple, self).__init__(*clauses, **kw)
 
     @property
index d60afdbacf771a0789ecc64939f84aefb0b991e5..b49fe92df5a57f7d021b16dcd2472ec48c913517 100644 (file)
@@ -2397,6 +2397,10 @@ class Values(Generative, FromClause):
         self.literal_binds = kw.pop("literal_binds", False)
         self.named_with_column = self.name is not None
 
+    @property
+    def _column_types(self):
+        return [col.type for col in self._column_args]
+
     @_generative
     def alias(self, name, **kw):
         """Return a new :class:`_expression.Values`
index 581573d17e282f3f4ed58684c7a119956d52e06a..09c7388abe7845935552f25da9e21d319cf98246 100644 (file)
@@ -3074,7 +3074,9 @@ class NullType(TypeEngine):
 
     def literal_processor(self, dialect):
         def process(value):
-            return "NULL"
+            raise exc.CompileError(
+                "Don't know how to render literal SQL value: %r" % value
+            )
 
         return process
 
@@ -3131,6 +3133,7 @@ else:
     _type_map[unicode] = Unicode()  # noqa
     _type_map[str] = String()
 
+
 _type_map_get = _type_map.get
 
 
index 1e4f2244290e1b1041aef52d0e498122e064a58b..43e8f85316282e6ed514a393453483ba3a538a81 100644 (file)
@@ -1,6 +1,8 @@
 from sqlalchemy import alias
 from sqlalchemy import Column
 from sqlalchemy import column
+from sqlalchemy import Enum
+from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
@@ -12,7 +14,9 @@ from sqlalchemy.sql import select
 from sqlalchemy.sql import Values
 from sqlalchemy.sql.compiler import FROM_LINTING
 from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.util import OrderedDict
 
 
 class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
@@ -52,34 +56,104 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
             Column("book_weight", Integer),
         )
 
+    def test_wrong_number_of_elements(self):
+        v1 = Values(
+            column("CaseSensitive", Integer),
+            column("has spaces", String),
+            name="Spaces and Cases",
+        ).data([(1, "textA", 99), (2, "textB", 88)])
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Wrong number of elements for 2-tuple: \(1, 'textA', 99\)",
+        ):
+            str(v1)
+
     def test_column_quoting(self):
         v1 = Values(
             column("CaseSensitive", Integer),
             column("has spaces", String),
+            column("number", Integer),
             name="Spaces and Cases",
         ).data([(1, "textA", 99), (2, "textB", 88)])
         self.assert_compile(
             select(v1),
             'SELECT "Spaces and Cases"."CaseSensitive", '
-            '"Spaces and Cases"."has spaces" FROM '
+            '"Spaces and Cases"."has spaces", "Spaces and Cases".number FROM '
             "(VALUES (:param_1, :param_2, :param_3), "
             "(:param_4, :param_5, :param_6)) "
-            'AS "Spaces and Cases" ("CaseSensitive", "has spaces")',
+            'AS "Spaces and Cases" ("CaseSensitive", "has spaces", number)',
         )
 
     @testing.fixture
     def literal_parameter_fixture(self):
-        def go(literal_binds):
-            return Values(
+        def go(literal_binds, omit=None):
+            cols = [
                 column("mykey", Integer),
                 column("mytext", String),
                 column("myint", Integer),
-                name="myvalues",
-                literal_binds=literal_binds,
+            ]
+            if omit:
+                for idx in omit:
+                    cols[idx] = column(cols[idx].name)
+
+            return Values(
+                *cols, name="myvalues", literal_binds=literal_binds
             ).data([(1, "textA", 99), (2, "textB", 88)])
 
         return go
 
+    @testing.fixture
+    def tricky_types_parameter_fixture(self):
+        class SomeEnum(object):
+            # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+            __members__ = OrderedDict()
+
+            def __init__(self, name, value, alias=None):
+                self.name = name
+                self.value = value
+                self.__members__[name] = self
+                setattr(self.__class__, name, self)
+                if alias:
+                    self.__members__[alias] = self
+                    setattr(self.__class__, alias, self)
+
+        one = SomeEnum("one", 1)
+        two = SomeEnum("two", 2)
+
+        class MumPyString(str):
+            """some kind of string, can't imagine where such a thing might
+            be found
+
+            """
+
+        class MumPyNumber(int):
+            """some kind of int, can't imagine where such a thing might
+            be found
+
+            """
+
+        def go(literal_binds, omit=None):
+            cols = [
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myenum", Enum(SomeEnum)),
+            ]
+            if omit:
+                for idx in omit:
+                    cols[idx] = column(cols[idx].name)
+
+            return Values(
+                *cols, name="myvalues", literal_binds=literal_binds
+            ).data(
+                [
+                    (MumPyNumber(1), MumPyString("textA"), one),
+                    (MumPyNumber(2), MumPyString("textB"), two),
+                ]
+            )
+
+        return go
+
     def test_bound_parameters(self, literal_parameter_fixture):
         literal_parameter_fixture = literal_parameter_fixture(False)
 
@@ -114,6 +188,49 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
             checkparams={},
         )
 
+    def test_literal_parameters_not_every_type_given(
+        self, literal_parameter_fixture
+    ):
+        literal_parameter_fixture = literal_parameter_fixture(True, omit=(1,))
+
+        stmt = select(literal_parameter_fixture)
+
+        self.assert_compile(
+            stmt,
+            "SELECT myvalues.mykey, myvalues.mytext, myvalues.myint FROM "
+            "(VALUES (1, 'textA', 99), (2, 'textB', 88)"
+            ") AS myvalues (mykey, mytext, myint)",
+            checkparams={},
+        )
+
+    def test_use_cols_tricky_not_every_type_given(
+        self, tricky_types_parameter_fixture
+    ):
+        literal_parameter_fixture = tricky_types_parameter_fixture(
+            True, omit=(1,)
+        )
+
+        stmt = select(literal_parameter_fixture)
+
+        with expect_raises_message(
+            exc.CompileError,
+            "Don't know how to render literal SQL value: 'textA'",
+        ):
+            str(stmt)
+
+    def test_use_cols_for_types(self, tricky_types_parameter_fixture):
+        literal_parameter_fixture = tricky_types_parameter_fixture(True)
+
+        stmt = select(literal_parameter_fixture)
+
+        self.assert_compile(
+            stmt,
+            "SELECT myvalues.mykey, myvalues.mytext, myvalues.myenum FROM "
+            "(VALUES (1, 'textA', 'one'), (2, 'textB', 'two')"
+            ") AS myvalues (mykey, mytext, myenum)",
+            checkparams={},
+        )
+
     def test_with_join_unnamed(self):
         people = self.tables.people
         values = Values(