]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
test Function(?:Element)._bind_param() with in_()
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2021 22:51:15 +0000 (17:51 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Feb 2021 22:53:29 +0000 (17:53 -0500)
Fixed 1.4 regression where the :meth:`_functions.Function.in_` method was
not covered by tests and failed to function properly in all cases.

Fixes: #5934
Change-Id: I93423a296e391aabd5594cb670d36b91ced0231d

doc/build/changelog/unreleased_14/5934.rst [new file with mode: 0644]
lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_14/5934.rst b/doc/build/changelog/unreleased_14/5934.rst
new file mode 100644 (file)
index 0000000..10d09ae
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5934
+
+    Fixed 1.4 regression where the :meth:`_functions.Function.in_` method was
+    not covered by tests and failed to function properly in all cases.
index 78f7ead2ebc8e77e454c2e8eca5b497503815326..641715327bff66a7cf91efeb58271bfcac22e759 100644 (file)
@@ -595,7 +595,7 @@ class FunctionElement(Executable, ColumnElement, FromClause, Generative):
         """
         return self.select().execute()
 
-    def _bind_param(self, operator, obj, type_=None):
+    def _bind_param(self, operator, obj, type_=None, **kw):
         return BindParameter(
             None,
             obj,
@@ -603,6 +603,7 @@ class FunctionElement(Executable, ColumnElement, FromClause, Generative):
             _compared_to_type=self.type,
             unique=True,
             type_=type_,
+            **kw
         )
 
     def self_group(self, against=None):
@@ -887,7 +888,7 @@ class Function(FunctionElement):
             )
             return kw["bind"]
 
-    def _bind_param(self, operator, obj, type_=None):
+    def _bind_param(self, operator, obj, type_=None, **kw):
         return BindParameter(
             self.name,
             obj,
@@ -895,6 +896,7 @@ class Function(FunctionElement):
             _compared_to_type=self.type,
             type_=type_,
             unique=True,
+            **kw
         )
 
 
index e460a90cbbe214a5aedf0da888e1b3d4433538b4..96e0a91291b593670302a0605027bc53f9a16bea 100644 (file)
@@ -31,9 +31,11 @@ from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import oracle
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects import sqlite
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import column
 from sqlalchemy.sql import functions
 from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy.sql import operators
 from sqlalchemy.sql import quoted_name
 from sqlalchemy.sql import table
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
@@ -99,6 +101,36 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
             functions._registry["_default"].pop("fake_func")
 
+    @testing.combinations(
+        (operators.in_op, [1, 2, 3], "myfunc() IN (1, 2, 3)"),
+        (operators.add, 5, "myfunc() + 5"),
+        (operators.eq, column("q"), "myfunc() = q"),
+        argnames="op,other,expected",
+    )
+    @testing.combinations((True,), (False,), argnames="use_custom")
+    def test_operators_custom(self, op, other, expected, use_custom):
+        if use_custom:
+
+            class MyFunc(FunctionElement):
+                name = "myfunc"
+                type = Integer()
+
+            @compiles(MyFunc)
+            def visit_myfunc(element, compiler, **kw):
+                return "myfunc(%s)" % compiler.process(element.clauses, **kw)
+
+            expr = op(MyFunc(), other)
+        else:
+            expr = op(func.myfunc(type_=Integer), other)
+
+        self.assert_compile(
+            select(1).where(expr),
+            "SELECT 1 WHERE %s" % (expected,),
+            literal_binds=True,
+            render_postcompile=True,
+            dialect="default_enhanced",
+        )
+
     def test_use_labels(self):
         self.assert_compile(
             select(func.foo()).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
@@ -106,8 +138,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_use_labels_function_element(self):
-        from sqlalchemy.ext.compiler import compiles
-
         class max_(FunctionElement):
             name = "max"