]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Clean code and tests. Use sqlalchemy warnings.
authorAdrien Berchet <adrien.berchet@gmail.com>
Fri, 29 Mar 2019 14:07:23 +0000 (15:07 +0100)
committerAdrien Berchet <adrien.berchet@gmail.com>
Fri, 29 Mar 2019 14:38:45 +0000 (15:38 +0100)
lib/sqlalchemy/sql/functions.py
test/sql/test_deprecations.py
test/sql/test_functions.py

index d104925a33fc4de8c82fce1a9da4b723f1d14795..3c268a59723e5a42c4336aefc4c5971975335207 100644 (file)
@@ -60,31 +60,32 @@ def register_function(identifier, fn, package="_default"):
     identifier = identifier.lower()
 
     # Check if a function with the same lowercase identifier is registered.
-    if identifier in _case_sensitive_reg[package]:
+    if identifier in _registry[package]:
         if (
             raw_identifier not in _case_sensitive_reg[package][identifier]
         ):
-            warnings.warn(
+            util.warn_deprecated(
                 "GenericFunction(s) '{}' are already registered with "
                 "different letter cases and might interact with {}.".format(
                     _case_sensitive_reg[package][identifier],
-                    raw_identifier),
-                sa_exc.SADeprecationWarning)
+                    raw_identifier))
+            # If a function with the same lowercase identifier is registered,
+            # then these 2 functions are considered as case-sensitive.
+            if len(_case_sensitive_reg[package][identifier]) == 1:
+                old_fn = reg[identifier]
+                del reg[identifier]
+                reg[_case_sensitive_reg[package][identifier][0]] = old_fn
         else:
             warnings.warn(
                 "The GenericFunction '{}' is already registered and "
-                "is going to be overriden.".format(identifier))
-
-        # If a function with the same lowercase identifier is registered, then
-        # these 2 functions are considered as case-sensitive.
-        if len(_case_sensitive_reg[package][identifier]) == 1:
-            old_fn = reg[identifier]
-            del reg[identifier]
-            reg[_case_sensitive_reg[package][identifier][0]] = old_fn
+                "is going to be overriden.".format(identifier),
+                sa_exc.SAWarning)
+
         reg[raw_identifier] = fn
     else:
         reg[identifier] = fn
 
+    # Add the raw_identifier to the case-sensitive registry
     if raw_identifier not in _case_sensitive_reg[package][identifier]:
         _case_sensitive_reg[package][identifier].append(raw_identifier)
 
index 7990cd56c69ab2301901bee4f5651fe26fa33233..f84859794ace3e34bab07f804052613368d3ddfb 100644 (file)
@@ -1,11 +1,16 @@
 #! coding: utf-8
 
+from copy import deepcopy
+import pytest
+
 from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import create_engine
+from sqlalchemy import DateTime
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import select
@@ -18,6 +23,9 @@ from sqlalchemy import util
 from sqlalchemy.engine import default
 from sqlalchemy.schema import DDL
 from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import functions
+from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -25,11 +33,20 @@ from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.assertions import expect_warnings
 
 
 class DeprecationWarningsTest(fixtures.TestBase):
     __backend__ = True
 
+    def setup_method(self):
+        self._registry = deepcopy(functions._registry)
+        self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg)
+
+    def teardown_method(self):
+        functions._registry = self._registry
+        functions._case_sensitive_reg = self._case_sensitive_reg
+
     def test_ident_preparer_force(self):
         preparer = testing.db.dialect.identifier_preparer
         preparer.quote("hi")
@@ -135,6 +152,59 @@ class DeprecationWarningsTest(fixtures.TestBase):
                 autoload_with=testing.db,
             )
 
+    def test_case_sensitive(self):
+        class MYFUNC(GenericFunction):
+            type = DateTime
+
+        assert isinstance(func.MYFUNC().type, DateTime)
+        assert isinstance(func.MyFunc().type, DateTime)
+        assert isinstance(func.mYfUnC().type, DateTime)
+        assert isinstance(func.myfunc().type, DateTime)
+
+        with testing.expect_deprecated():
+            class MyFunc(GenericFunction):
+                type = Integer
+
+        assert isinstance(func.MYFUNC().type, DateTime)
+        assert isinstance(func.MyFunc().type, Integer)
+        with pytest.raises(AssertionError):
+            assert isinstance(func.mYfUnC().type, Integer)
+        with pytest.raises(AssertionError):
+            assert isinstance(func.myfunc().type, Integer)
+
+    def test_replace_function_case_sensitive(self):
+
+        class replacable_func(GenericFunction):
+            __return_type__ = Integer
+            identifier = 'replacable_func'
+
+        assert isinstance(func.Replacable_Func().type, Integer)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, Integer)
+        assert isinstance(func.replacable_func().type, Integer)
+
+        with testing.expect_deprecated():
+            class Replacable_Func(GenericFunction):
+                __return_type__ = DateTime
+                identifier = 'Replacable_Func'
+
+        assert isinstance(func.Replacable_Func().type, DateTime)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+        assert isinstance(func.replacable_func().type, Integer)
+
+        with expect_warnings():
+            class replacable_func_override(GenericFunction):
+                __return_type__ = DateTime
+                identifier = 'replacable_func'
+
+        with expect_warnings():
+            class Replacable_Func_override(GenericFunction):
+                __return_type__ = Integer
+                identifier = 'Replacable_Func'
+
+        assert isinstance(func.Replacable_Func().type, Integer)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+        assert isinstance(func.replacable_func().type, DateTime)
+
 
 class DDLListenerDeprecationsTest(fixtures.TestBase):
     def setup(self):
index 311b476a806a59dc6d8f9e4c831bf69af9f5f90a..62107ee5bda61093bdbd6abb136d276e02e35747 100644 (file)
@@ -2,8 +2,6 @@ from copy import deepcopy
 import datetime
 import decimal
 
-import pytest
-
 from sqlalchemy import ARRAY
 from sqlalchemy import bindparam
 from sqlalchemy import Boolean
@@ -35,7 +33,6 @@ from sqlalchemy.sql import table
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
 from sqlalchemy.sql.functions import FunctionElement
 from sqlalchemy.sql.functions import GenericFunction
-from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -43,6 +40,7 @@ from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.engines import all_dialects
 
 
@@ -96,6 +94,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                 dialect=dialect,
             )
 
+            functions._registry['_default'].pop('fake_func')
+            functions._case_sensitive_reg['_default'].pop('fake_func')
+
     def test_use_labels(self):
         self.assert_compile(
             select([func.foo()], use_labels=True), "SELECT foo() AS foo_1"
@@ -237,46 +238,23 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         assert isinstance(func.mYfUnC().type, DateTime)
         assert isinstance(func.myfunc().type, DateTime)
 
-        with testing.expect_deprecated():
-            class MyFunc(GenericFunction):
-                type = Integer
-
-        assert isinstance(func.MYFUNC().type, DateTime)
-        assert isinstance(func.MyFunc().type, Integer)
-        with pytest.raises(AssertionError):
-            assert isinstance(func.mYfUnC().type, Integer)
-        with pytest.raises(AssertionError):
-            assert isinstance(func.myfunc().type, Integer)
-
     def test_replace_function(self):
 
         class replacable_func(GenericFunction):
-            __return_type__ = Integer
+            type = Integer
             identifier = 'replacable_func'
 
         assert isinstance(func.Replacable_Func().type, Integer)
         assert isinstance(func.RePlAcaBlE_fUnC().type, Integer)
         assert isinstance(func.replacable_func().type, Integer)
 
-        with testing.expect_deprecated():
-            class Replacable_Func(GenericFunction):
-                __return_type__ = DateTime
-                identifier = 'Replacable_Func'
+        with expect_warnings():
+            class replacable_func_override(GenericFunction):
+                type = DateTime
+                identifier = 'replacable_func'
 
         assert isinstance(func.Replacable_Func().type, DateTime)
-        assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
-        assert isinstance(func.replacable_func().type, Integer)
-
-        class replacable_func_override(GenericFunction):
-            __return_type__ = DateTime
-            identifier = 'replacable_func'
-
-        class Replacable_Func_override(GenericFunction):
-            __return_type__ = Integer
-            identifier = 'Replacable_Func'
-
-        assert isinstance(func.Replacable_Func().type, Integer)
-        assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, DateTime)
         assert isinstance(func.replacable_func().type, DateTime)
 
     def test_custom_w_custom_name(self):
@@ -341,22 +319,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(func.bUf4_(), "BufferFour()")
         self.assert_compile(func.buf4(), "BufferFour()")
 
-        with testing.expect_deprecated():
-            class geobufferfour_lowercase(GenericFunction):
-                type = Integer
-                name = "BufferFour_lowercase"
-                identifier = "buf4"
-
-        with testing.expect_deprecated():
-            class geobufferfour_random_case(GenericFunction):
-                type = Integer
-                name = "BuFferFouR"
-                identifier = "BuF4"
-
-        self.assert_compile(func.Buf4(), "BufferFour()")
-        self.assert_compile(func.BuF4(), "BuFferFouR()")
-        self.assert_compile(func.buf4(), "BufferFour_lowercase()")
-
     def test_custom_args(self):
         class myfunc(GenericFunction):
             pass