]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
The GenericFunction class is no more registered
authorAdrien Berchet <adrien.berchet@gmail.com>
Thu, 2 May 2019 10:16:54 +0000 (12:16 +0200)
committerAdrien Berchet <adrien.berchet@gmail.com>
Thu, 2 May 2019 10:16:54 +0000 (12:16 +0200)
lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

index feb4fdb90b8c241c039e52946d0257cd04e4cea0..d0aa239881222d90d9dcb5ce08e74649b6e5de24 100644 (file)
@@ -603,7 +603,17 @@ class _GenericMeta(VisitableType):
             # legacy
             if "__return_type__" in clsdict:
                 cls.type = clsdict["__return_type__"]
-            register_function(identifier, cls, package)
+
+            # Check _register attribute status
+            cls._register = getattr(cls, '_register', True)
+
+            # Register the function if required
+            if cls._register:
+                register_function(identifier, cls, package)
+            else:
+                # Set _register to True to register child classes by default
+                cls._register = True
+
         super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
 
 
@@ -671,6 +681,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
     """
 
     coerce_arguments = True
+    _register = False
 
     def __init__(self, *args, **kwargs):
         parsed_args = kwargs.pop("_parsed_args", None)
index ac711c669e48dde6b27ecab1c5a89937b5923027..feb2b9b891570e2b99775cb655f9cf49e549f3ca 100644 (file)
@@ -1075,3 +1075,43 @@ def exec_sorted(statement, *args, **kw):
     return sorted(
         [tuple(row) for row in statement.execute(*args, **kw).fetchall()]
     )
+
+
+class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def setup(self):
+        self._registry = deepcopy(functions._registry)
+
+    def teardown(self):
+        functions._registry = self._registry
+
+    def test_GenericFunction_is_registered(self):
+        assert 'GenericFunction' not in functions._registry['_default']
+
+    def test_register_function(self):
+
+            # test generic function registering
+            class registered_func(GenericFunction):
+                _register = True
+
+                def __init__(self, *args, **kwargs):
+                    GenericFunction.__init__(self, *args, **kwargs)
+
+            class registered_func_child(registered_func):
+                type = sqltypes.Integer
+
+            assert 'registered_func' in functions._registry['_default']
+            assert isinstance(func.registered_func_child().type, Integer)
+
+            class not_registered_func(GenericFunction):
+                _register = False
+
+                def __init__(self, *args, **kwargs):
+                    GenericFunction.__init__(self, *args, **kwargs)
+
+            class not_registered_func_child(not_registered_func):
+                type = sqltypes.Integer
+
+            assert 'not_registered_func' not in functions._registry['_default']
+            assert isinstance(func.not_registered_func_child().type, Integer)