]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- tweak the GenericFunction constructor more so that it's action in parsing the
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Aug 2012 15:30:42 +0000 (11:30 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Aug 2012 15:30:42 +0000 (11:30 -0400)
arguments is easier to understand
- add a test to ensure generic function can have a custom name

lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

index fe64764b77f5e3daf8a4a2b2467233523c8a7a0e..24f7e81b4132f24b5799eda6488a6237c12ced1b 100644 (file)
@@ -25,11 +25,6 @@ class _GenericMeta(VisitableType):
         reg[name] = cls
         super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
 
-    def __call__(cls, *args, **kwargs):
-        if cls.coerce_arguments:
-            args = [_literal_as_binds(c) for c in args]
-        return type.__call__(cls, *args, **kwargs)
-
 class GenericFunction(Function):
     """Define a 'generic' function.
 
@@ -88,11 +83,14 @@ class GenericFunction(Function):
 
     coerce_arguments = True
     def __init__(self, *args, **kwargs):
+        parsed_args = kwargs.pop('_parsed_args', None)
+        if parsed_args is None:
+            parsed_args = [_literal_as_binds(c) for c in args]
         self.packagenames = []
         self._bind = kwargs.get('bind', None)
         self.clause_expr = ClauseList(
                 operator=operators.comma_op,
-                group_contents=True, *args).self_group()
+                group_contents=True, *parsed_args).self_group()
         self.type = sqltypes.to_instance(
             kwargs.pop("type_", None) or getattr(self, 'type', None))
 
@@ -108,7 +106,6 @@ class next_value(GenericFunction):
     """
     type = sqltypes.Integer()
     name = "next_value"
-    coerce_arguments = False
 
     def __init__(self, seq, **kw):
         assert isinstance(seq, schema.Sequence), \
@@ -128,7 +125,9 @@ class ReturnTypeFromArgs(GenericFunction):
     """Define a function whose return type is the same as its arguments."""
 
     def __init__(self, *args, **kwargs):
+        args = [_literal_as_binds(c) for c in args]
         kwargs.setdefault('type_', _type_from_args(args))
+        kwargs['_parsed_args'] = args
         GenericFunction.__init__(self, *args, **kwargs)
 
 class coalesce(ReturnTypeFromArgs):
index b69f8f6ba53edc70da86bb839802c68a63d8134a..fc227f37533405561317d72e97ffb46c3f4ed603 100644 (file)
@@ -114,6 +114,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         assert isinstance(func.mypackage.myfunc(), f1)
         assert isinstance(func.myotherpackage.myfunc(), f2)
 
+    def test_custom_name(self):
+        class MyFunction(GenericFunction):
+            name = 'my_func'
+
+            def __init__(self, *args):
+                args = args + (3,)
+                super(MyFunction, self).__init__(*args)
+
+        self.assert_compile(
+            func.my_func(1, 2),
+            "my_func(:param_1, :param_2, :param_3)"
+        )
+
     def test_custom_args(self):
         class myfunc(GenericFunction):
             pass