]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow optional *args with base AnsiFunction
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Nov 2018 15:13:03 +0000 (10:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Nov 2018 04:46:42 +0000 (23:46 -0500)
Amended the :class:`.AnsiFunction` class, the base of common SQL
functions like ``CURRENT_TIMESTAMP``, to accept positional arguments
like a regular ad-hoc function.  This to suit the case that many of
these functions on specific backends accept arguments such as
"fractional seconds" precision and such.  If the function is created
with arguments, it renders the the parenthesis and the arguments.  If
no arguents are present, the compiler generates the non-parenthesized form.

Fixes: #4386
Change-Id: Ic492ef177e4987cec99ec4d95f55292be8daa087

doc/build/changelog/unreleased_13/4386.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_13/4386.rst b/doc/build/changelog/unreleased_13/4386.rst
new file mode 100644 (file)
index 0000000..24e9f84
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+   :tags: feature, sql
+   :tickets: 4386
+
+   Amended the :class:`.AnsiFunction` class, the base of common SQL
+   functions like ``CURRENT_TIMESTAMP``, to accept positional arguments
+   like a regular ad-hoc function.  This to suit the case that many of
+   these functions on specific backends accept arguments such as
+   "fractional seconds" precision and such.  If the function is created
+   with arguments, it renders the the parenthesis and the arguments.  If
+   no arguents are present, the compiler generates the non-parenthesized form.
index c2a23a758c2e827422eba09480646ec3fd743b72..80ed707edf3f06ad3f30744a7f18175608953f74 100644 (file)
@@ -111,20 +111,20 @@ OPERATORS = {
 }
 
 FUNCTIONS = {
-    functions.coalesce: 'coalesce%(expr)s',
+    functions.coalesce: 'coalesce',
     functions.current_date: 'CURRENT_DATE',
     functions.current_time: 'CURRENT_TIME',
     functions.current_timestamp: 'CURRENT_TIMESTAMP',
     functions.current_user: 'CURRENT_USER',
     functions.localtime: 'LOCALTIME',
     functions.localtimestamp: 'LOCALTIMESTAMP',
-    functions.random: 'random%(expr)s',
+    functions.random: 'random',
     functions.sysdate: 'sysdate',
     functions.session_user: 'SESSION_USER',
     functions.user: 'USER',
-    functions.cube: 'CUBE%(expr)s',
-    functions.rollup: 'ROLLUP%(expr)s',
-    functions.grouping_sets: 'GROUPING SETS%(expr)s',
+    functions.cube: 'CUBE',
+    functions.rollup: 'ROLLUP',
+    functions.grouping_sets: 'GROUPING SETS',
 }
 
 EXTRACT_MAP = {
@@ -927,7 +927,12 @@ class SQLCompiler(Compiled):
         if disp:
             return disp(func, **kwargs)
         else:
-            name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+            name = FUNCTIONS.get(func.__class__, None)
+            if name:
+                if func._has_args:
+                    name += "%(expr)s"
+            else:
+                name = func.name + "%(expr)s"
             return ".".join(list(func.packagenames) + [name]) % \
                 {'expr': self.function_argspec(func, **kwargs)}
 
index 5cea7750a73f248b4a067b59bef2fc1148406d69..4b4d2d463d8cd01f5af0e35cde04747949da8fd1 100644 (file)
@@ -54,10 +54,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
 
     packagenames = ()
 
+    _has_args = False
+
     def __init__(self, *clauses, **kwargs):
         """Construct a :class:`.FunctionElement`.
         """
         args = [_literal_as_binds(c, self.name) for c in clauses]
+        self._has_args = self._has_args or bool(args)
         self.clause_expr = ClauseList(
             operator=operators.comma_op,
             group_contents=True, *args).\
@@ -635,6 +638,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
         parsed_args = kwargs.pop('_parsed_args', None)
         if parsed_args is None:
             parsed_args = [_literal_as_binds(c, self.name) for c in args]
+        self._has_args = self._has_args or bool(parsed_args)
         self.packagenames = []
         self._bind = kwargs.get('bind', None)
         self.clause_expr = ClauseList(
@@ -671,8 +675,8 @@ class next_value(GenericFunction):
 
 
 class AnsiFunction(GenericFunction):
-    def __init__(self, **kwargs):
-        GenericFunction.__init__(self, **kwargs)
+    def __init__(self, *args, **kwargs):
+        GenericFunction.__init__(self, *args, **kwargs)
 
 
 class ReturnTypeFromArgs(GenericFunction):
@@ -686,7 +690,7 @@ class ReturnTypeFromArgs(GenericFunction):
 
 
 class coalesce(ReturnTypeFromArgs):
-    pass
+    _has_args = True
 
 
 class max(ReturnTypeFromArgs):
@@ -717,7 +721,7 @@ class char_length(GenericFunction):
 
 
 class random(GenericFunction):
-    pass
+    _has_args = True
 
 
 class count(GenericFunction):
@@ -937,6 +941,7 @@ class cube(GenericFunction):
     .. versionadded:: 1.2
 
     """
+    _has_args = True
 
 
 class rollup(GenericFunction):
@@ -952,6 +957,7 @@ class rollup(GenericFunction):
     .. versionadded:: 1.2
 
     """
+    _has_args = True
 
 
 class grouping_sets(GenericFunction):
@@ -984,3 +990,4 @@ class grouping_sets(GenericFunction):
     .. versionadded:: 1.2
 
     """
+    _has_args = True
index 48d5fc37f861b87135d416f2769b2f83bf61f220..ffc72e9eee2e585793e89c8344e09c2aee74ee4c 100644 (file)
@@ -14,7 +14,7 @@ import decimal
 from sqlalchemy import testing
 from sqlalchemy.testing import fixtures, AssertsCompiledSQL, engines
 from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle
-from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_message, assert_raises
 
 table1 = table('mytable',
                column('myid', Integer),
@@ -133,12 +133,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             pass
 
         assert isinstance(func.myfunc(), myfunc)
+        self.assert_compile(func.myfunc(), "myfunc()")
 
     def test_custom_type(self):
         class myfunc(GenericFunction):
             type = DateTime
 
         assert isinstance(func.myfunc().type, DateTime)
+        self.assert_compile(func.myfunc(), "myfunc()")
 
     def test_custom_legacy_type(self):
         # in case someone was using this system
@@ -228,24 +230,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         c = column('abc')
         self.assert_compile(func.count(c), 'count(abc)')
 
-    def test_constructor(self):
-        try:
-            func.current_timestamp('somearg')
-            assert False
-        except TypeError:
-            assert True
-
-        try:
-            func.char_length('a', 'b')
-            assert False
-        except TypeError:
-            assert True
+    def test_ansi_functions_with_args(self):
+        ct = func.current_timestamp('somearg')
+        self.assert_compile(ct, "CURRENT_TIMESTAMP(:current_timestamp_1)")
 
-        try:
-            func.char_length()
-            assert False
-        except TypeError:
-            assert True
+    def test_char_length_fixed_args(self):
+        assert_raises(
+            TypeError,
+            func.char_length, 'a', 'b'
+        )
+        assert_raises(
+            TypeError,
+            func.char_length
+        )
 
     def test_return_type_detection(self):