]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Try to deal with the len() calls.
authorAdrien Berchet <adrien.berchet@gmail.com>
Sun, 7 Apr 2019 21:27:52 +0000 (23:27 +0200)
committerAdrien Berchet <adrien.berchet@gmail.com>
Sun, 7 Apr 2019 21:47:38 +0000 (23:47 +0200)
lib/sqlalchemy/sql/functions.py
test/sql/test_deprecations.py

index dc97a97f241e1e4719573c5ea79de63ca8effbb0..f5720408bc8e319f97902f6adaf35e77a0ab16ba 100644 (file)
@@ -60,7 +60,22 @@ def register_function(identifier, fn, package="_default"):
     identifier = identifier.lower()
 
     # Check if a function with the same lowercase identifier is registered.
-    if case_sensitive_reg[identifier]:
+    if identifier in reg:
+        if raw_identifier in case_sensitive_reg[identifier]:
+            warnings.warn(
+                "The GenericFunction '{}' is already registered and "
+                "is going to be overriden.".format(identifier),
+                sa_exc.SAWarning)
+            reg[identifier] = fn
+        else:
+            # If a function with the same lowercase identifier is registered,
+            # then these 2 functions are considered as case-sensitive.
+            # Note: This case should raise an error in a later release.
+            reg.pop(identifier)
+
+    # Check if a function with different letter case identifier is registered.
+    elif identifier in case_sensitive_reg:
+        # Note: This case will be removed in a later release.
         if (
             raw_identifier not in case_sensitive_reg[identifier]
         ):
@@ -69,11 +84,6 @@ def register_function(identifier, fn, package="_default"):
                 "different letter cases and might interact with {}.".format(
                     list(case_sensitive_reg[identifier].keys()),
                     raw_identifier))
-            # If a function with the same lowercase identifier is registered,
-            # then these 2 functions are considered as case-sensitive.
-            if len(case_sensitive_reg[identifier]) == 1:
-                reg.pop(identifier, None)
-                reg.update(case_sensitive_reg[identifier])
 
         else:
             warnings.warn(
@@ -81,13 +91,12 @@ def register_function(identifier, fn, package="_default"):
                 "is going to be overriden.".format(identifier),
                 sa_exc.SAWarning)
 
-        reg[raw_identifier] = fn
+    # Register by default
     else:
         reg[identifier] = fn
 
-    # Add the raw_identifier to the case-sensitive registry
-    if raw_identifier not in case_sensitive_reg[identifier]:
-        case_sensitive_reg[identifier][raw_identifier] = fn
+    # Always register in case-sensitive registry
+    case_sensitive_reg[identifier][raw_identifier] = fn
 
 
 class FunctionElement(Executable, ColumnElement, FromClause):
@@ -480,12 +489,9 @@ class _FunctionGenerator(object):
         if package is not None:
             reg = _registry[package]
             case_sensitive_reg = _case_sensitive_registry[package]
-            if (
-                len(case_sensitive_reg.get(fname.lower(), [])) > 1
-            ):
-                func = reg.get(fname)
-            else:
-                func = reg.get(fname.lower())
+            func = reg.get(fname.lower())
+            if func is None and fname.lower() in case_sensitive_reg:
+                func = case_sensitive_reg[fname.lower()].get(fname)
             if func is not None:
                 return func(*c, **o)
 
index 4b55332b490babb2fcb109b7c88abc994a088c47..7a65c59efcd54d78041a688518785c993df94e85 100644 (file)
@@ -157,8 +157,8 @@ class DeprecationWarningsTest(fixtures.TestBase):
             )
 
     def test_case_sensitive(self):
-        reg = functions._registry
-        cs_reg = functions._case_sensitive_registry
+        reg = functions._registry['_default']
+        cs_reg = functions._case_sensitive_registry['_default']
 
         class MYFUNC(GenericFunction):
             type = DateTime
@@ -168,11 +168,11 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.mYfUnC().type, DateTime)
         assert isinstance(func.myfunc().type, DateTime)
 
-        in_("myfunc", reg['_default'])
-        not_in_("MYFUNC", reg['_default'])
-        not_in_("MyFunc", reg['_default'])
-        in_("myfunc", cs_reg['_default'])
-        eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC'])
+        in_("myfunc", reg)
+        not_in_("MYFUNC", reg)
+        not_in_("MyFunc", reg)
+        in_("myfunc", cs_reg)
+        eq_(list(cs_reg['myfunc'].keys()), ['MYFUNC'])
 
         with testing.expect_deprecated():
             class MyFunc(GenericFunction):
@@ -185,15 +185,15 @@ class DeprecationWarningsTest(fixtures.TestBase):
         with pytest.raises(AssertionError):
             assert isinstance(func.myfunc().type, Integer)
 
-        not_in_("myfunc", reg['_default'])
-        in_("MYFUNC", reg['_default'])
-        in_("MyFunc", reg['_default'])
-        in_("myfunc", cs_reg['_default'])
-        eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC', 'MyFunc'])
+        not_in_("myfunc", reg)
+        not_in_("MYFUNC", reg)
+        not_in_("MyFunc", reg)
+        in_("myfunc", cs_reg)
+        eq_(list(cs_reg['myfunc'].keys()), ['MYFUNC', 'MyFunc'])
 
     def test_replace_function_case_sensitive(self):
-        reg = functions._registry
-        cs_reg = functions._case_sensitive_registry
+        reg = functions._registry['_default']
+        cs_reg = functions._case_sensitive_registry['_default']
 
         class replaceable_func(GenericFunction):
             type = Integer
@@ -204,11 +204,11 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.RePlAcEaBlE_fUnC().type, Integer)
         assert isinstance(func.replaceable_func().type, Integer)
 
-        in_("replaceable_func", reg['_default'])
-        not_in_("REPLACEABLE_FUNC", reg['_default'])
-        not_in_("Replaceable_Func", reg['_default'])
-        in_("replaceable_func", cs_reg['_default'])
-        eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+        in_("replaceable_func", reg)
+        not_in_("REPLACEABLE_FUNC", reg)
+        not_in_("Replaceable_Func", reg)
+        in_("replaceable_func", cs_reg)
+        eq_(list(cs_reg['replaceable_func'].keys()),
             ['REPLACEABLE_FUNC'])
 
         with testing.expect_deprecated():
@@ -221,11 +221,11 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType)
         assert isinstance(func.replaceable_func().type, NullType)
 
-        not_in_("replaceable_func", reg['_default'])
-        in_("REPLACEABLE_FUNC", reg['_default'])
-        in_("Replaceable_Func", reg['_default'])
-        in_("replaceable_func", cs_reg['_default'])
-        eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+        not_in_("replaceable_func", reg)
+        not_in_("REPLACEABLE_FUNC", reg)
+        not_in_("Replaceable_Func", reg)
+        in_("replaceable_func", cs_reg)
+        eq_(list(cs_reg['replaceable_func'].keys()),
             ['REPLACEABLE_FUNC', 'Replaceable_Func'])
 
         with expect_warnings():
@@ -248,11 +248,11 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType)
         assert isinstance(func.replaceable_func().type, String)
 
-        in_("replaceable_func", reg['_default'])
-        in_("REPLACEABLE_FUNC", reg['_default'])
-        in_("Replaceable_Func", reg['_default'])
-        in_("replaceable_func", cs_reg['_default'])
-        eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+        not_in_("replaceable_func", reg)
+        not_in_("REPLACEABLE_FUNC", reg)
+        not_in_("Replaceable_Func", reg)
+        in_("replaceable_func", cs_reg)
+        eq_(list(cs_reg['replaceable_func'].keys()),
             ['REPLACEABLE_FUNC', 'Replaceable_Func', 'replaceable_func'])