--- /dev/null
+.. change::
+ :tags: feature, sql
+ :tickets: 4386
+
+ Amended the :class:`.AnsiFunction` class, the base of common SQL
+ functions like ``CURRENT_TIMESTAMP``, to accept positional arguments
+ like a regular ad-hoc function. This to suit the case that many of
+ these functions on specific backends accept arguments such as
+ "fractional seconds" precision and such. If the function is created
+ with arguments, it renders the the parenthesis and the arguments. If
+ no arguents are present, the compiler generates the non-parenthesized form.
}
FUNCTIONS = {
- functions.coalesce: 'coalesce%(expr)s',
+ functions.coalesce: 'coalesce',
functions.current_date: 'CURRENT_DATE',
functions.current_time: 'CURRENT_TIME',
functions.current_timestamp: 'CURRENT_TIMESTAMP',
functions.current_user: 'CURRENT_USER',
functions.localtime: 'LOCALTIME',
functions.localtimestamp: 'LOCALTIMESTAMP',
- functions.random: 'random%(expr)s',
+ functions.random: 'random',
functions.sysdate: 'sysdate',
functions.session_user: 'SESSION_USER',
functions.user: 'USER',
- functions.cube: 'CUBE%(expr)s',
- functions.rollup: 'ROLLUP%(expr)s',
- functions.grouping_sets: 'GROUPING SETS%(expr)s',
+ functions.cube: 'CUBE',
+ functions.rollup: 'ROLLUP',
+ functions.grouping_sets: 'GROUPING SETS',
}
EXTRACT_MAP = {
if disp:
return disp(func, **kwargs)
else:
- name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+ name = FUNCTIONS.get(func.__class__, None)
+ if name:
+ if func._has_args:
+ name += "%(expr)s"
+ else:
+ name = func.name + "%(expr)s"
return ".".join(list(func.packagenames) + [name]) % \
{'expr': self.function_argspec(func, **kwargs)}
packagenames = ()
+ _has_args = False
+
def __init__(self, *clauses, **kwargs):
"""Construct a :class:`.FunctionElement`.
"""
args = [_literal_as_binds(c, self.name) for c in clauses]
+ self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
operator=operators.comma_op,
group_contents=True, *args).\
parsed_args = kwargs.pop('_parsed_args', None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
+ self._has_args = self._has_args or bool(parsed_args)
self.packagenames = []
self._bind = kwargs.get('bind', None)
self.clause_expr = ClauseList(
class AnsiFunction(GenericFunction):
- def __init__(self, **kwargs):
- GenericFunction.__init__(self, **kwargs)
+ def __init__(self, *args, **kwargs):
+ GenericFunction.__init__(self, *args, **kwargs)
class ReturnTypeFromArgs(GenericFunction):
class coalesce(ReturnTypeFromArgs):
- pass
+ _has_args = True
class max(ReturnTypeFromArgs):
class random(GenericFunction):
- pass
+ _has_args = True
class count(GenericFunction):
.. versionadded:: 1.2
"""
+ _has_args = True
class rollup(GenericFunction):
.. versionadded:: 1.2
"""
+ _has_args = True
class grouping_sets(GenericFunction):
.. versionadded:: 1.2
"""
+ _has_args = True
from sqlalchemy import testing
from sqlalchemy.testing import fixtures, AssertsCompiledSQL, engines
from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle
-from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_message, assert_raises
table1 = table('mytable',
column('myid', Integer),
pass
assert isinstance(func.myfunc(), myfunc)
+ self.assert_compile(func.myfunc(), "myfunc()")
def test_custom_type(self):
class myfunc(GenericFunction):
type = DateTime
assert isinstance(func.myfunc().type, DateTime)
+ self.assert_compile(func.myfunc(), "myfunc()")
def test_custom_legacy_type(self):
# in case someone was using this system
c = column('abc')
self.assert_compile(func.count(c), 'count(abc)')
- def test_constructor(self):
- try:
- func.current_timestamp('somearg')
- assert False
- except TypeError:
- assert True
-
- try:
- func.char_length('a', 'b')
- assert False
- except TypeError:
- assert True
+ def test_ansi_functions_with_args(self):
+ ct = func.current_timestamp('somearg')
+ self.assert_compile(ct, "CURRENT_TIMESTAMP(:current_timestamp_1)")
- try:
- func.char_length()
- assert False
- except TypeError:
- assert True
+ def test_char_length_fixed_args(self):
+ assert_raises(
+ TypeError,
+ func.char_length, 'a', 'b'
+ )
+ assert_raises(
+ TypeError,
+ func.char_length
+ )
def test_return_type_detection(self):