]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use dict instead of list and use more sqlalchemy.testing functions.
authorAdrien Berchet <adrien.berchet@gmail.com>
Sun, 7 Apr 2019 20:25:40 +0000 (22:25 +0200)
committerAdrien Berchet <adrien.berchet@gmail.com>
Sun, 7 Apr 2019 20:31:05 +0000 (22:31 +0200)
lib/sqlalchemy/sql/functions.py
test/sql/test_deprecations.py
test/sql/test_functions.py

index ed3fbaf2f4d95bb5d5e94b7e1d763454a30270ab..dc97a97f241e1e4719573c5ea79de63ca8effbb0 100644 (file)
@@ -8,7 +8,6 @@
 """SQL function API, factories, and built-in functions.
 
 """
-from functools import partial
 import warnings
 
 from . import annotation
@@ -40,10 +39,10 @@ from .. import exc as sa_exc
 from .. import util
 
 
-_RegistryList = partial(util.defaultdict, list)
-
 _registry = util.defaultdict(dict)
-_case_sensitive_reg = util.defaultdict(_RegistryList)
+_case_sensitive_registry = util.defaultdict(
+    lambda: util.defaultdict(dict)
+)
 
 
 def register_function(identifier, fn, package="_default"):
@@ -56,25 +55,26 @@ def register_function(identifier, fn, package="_default"):
 
     """
     reg = _registry[package]
+    case_sensitive_reg = _case_sensitive_registry[package]
     raw_identifier = identifier
     identifier = identifier.lower()
 
     # Check if a function with the same lowercase identifier is registered.
-    if identifier in _case_sensitive_reg[package]:
+    if case_sensitive_reg[identifier]:
         if (
-            raw_identifier not in _case_sensitive_reg[package][identifier]
+            raw_identifier not in case_sensitive_reg[identifier]
         ):
             util.warn_deprecated(
-                "GenericFunction(s) '{}' are already registered with "
+                "GenericFunction(s) {} are already registered with "
                 "different letter cases and might interact with {}.".format(
-                    _case_sensitive_reg[package][identifier],
+                    list(case_sensitive_reg[identifier].keys()),
                     raw_identifier))
             # If a function with the same lowercase identifier is registered,
             # then these 2 functions are considered as case-sensitive.
-            if len(_case_sensitive_reg[package][identifier]) == 1:
-                old_fn = reg[identifier]
-                del reg[identifier]
-                reg[_case_sensitive_reg[package][identifier][0]] = old_fn
+            if len(case_sensitive_reg[identifier]) == 1:
+                reg.pop(identifier, None)
+                reg.update(case_sensitive_reg[identifier])
+
         else:
             warnings.warn(
                 "The GenericFunction '{}' is already registered and "
@@ -86,8 +86,8 @@ def register_function(identifier, fn, package="_default"):
         reg[identifier] = fn
 
     # Add the raw_identifier to the case-sensitive registry
-    if raw_identifier not in _case_sensitive_reg[package][identifier]:
-        _case_sensitive_reg[package][identifier].append(raw_identifier)
+    if raw_identifier not in case_sensitive_reg[identifier]:
+        case_sensitive_reg[identifier][raw_identifier] = fn
 
 
 class FunctionElement(Executable, ColumnElement, FromClause):
@@ -478,12 +478,14 @@ class _FunctionGenerator(object):
             package = None
 
         if package is not None:
+            reg = _registry[package]
+            case_sensitive_reg = _case_sensitive_registry[package]
             if (
-                len(_case_sensitive_reg[package].get(fname.lower(), [])) > 1
+                len(case_sensitive_reg.get(fname.lower(), [])) > 1
             ):
-                func = _registry[package].get(fname)
+                func = reg.get(fname)
             else:
-                func = _registry[package].get(fname.lower())
+                func = reg.get(fname.lower())
             if func is not None:
                 return func(*c, **o)
 
index 27cd8dcc53a66cc2b6809d2cf389dfe512bb24cd..4b55332b490babb2fcb109b7c88abc994a088c47 100644 (file)
@@ -33,7 +33,9 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import in_
 from sqlalchemy.testing import mock
+from sqlalchemy.testing import not_in_
 from sqlalchemy.testing.assertions import expect_warnings
 
 
@@ -42,11 +44,12 @@ class DeprecationWarningsTest(fixtures.TestBase):
 
     def setup_method(self):
         self._registry = deepcopy(functions._registry)
-        self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg)
+        self._case_sensitive_registry = deepcopy(
+            functions._case_sensitive_registry)
 
     def teardown_method(self):
         functions._registry = self._registry
-        functions._case_sensitive_reg = self._case_sensitive_reg
+        functions._case_sensitive_registry = self._case_sensitive_registry
 
     def test_ident_preparer_force(self):
         preparer = testing.db.dialect.identifier_preparer
@@ -155,7 +158,7 @@ class DeprecationWarningsTest(fixtures.TestBase):
 
     def test_case_sensitive(self):
         reg = functions._registry
-        cs_reg = functions._case_sensitive_reg
+        cs_reg = functions._case_sensitive_registry
 
         class MYFUNC(GenericFunction):
             type = DateTime
@@ -165,11 +168,11 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.mYfUnC().type, DateTime)
         assert isinstance(func.myfunc().type, DateTime)
 
-        assert "myfunc" in reg['_default']
-        assert "MYFUNC" not in reg['_default']
-        assert "MyFunc" not in reg['_default']
-        assert "myfunc" in cs_reg['_default']
-        assert cs_reg['_default']['myfunc'] == ['MYFUNC']
+        in_("myfunc", reg['_default'])
+        not_in_("MYFUNC", reg['_default'])
+        not_in_("MyFunc", reg['_default'])
+        in_("myfunc", cs_reg['_default'])
+        eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC'])
 
         with testing.expect_deprecated():
             class MyFunc(GenericFunction):
@@ -182,15 +185,15 @@ class DeprecationWarningsTest(fixtures.TestBase):
         with pytest.raises(AssertionError):
             assert isinstance(func.myfunc().type, Integer)
 
-        assert "myfunc" not in reg['_default']
-        assert "MYFUNC" in reg['_default']
-        assert "MyFunc" in reg['_default']
-        assert "myfunc" in cs_reg['_default']
-        assert cs_reg['_default']['myfunc'] == ['MYFUNC', 'MyFunc']
+        not_in_("myfunc", reg['_default'])
+        in_("MYFUNC", reg['_default'])
+        in_("MyFunc", reg['_default'])
+        in_("myfunc", cs_reg['_default'])
+        eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC', 'MyFunc'])
 
     def test_replace_function_case_sensitive(self):
         reg = functions._registry
-        cs_reg = functions._case_sensitive_reg
+        cs_reg = functions._case_sensitive_registry
 
         class replaceable_func(GenericFunction):
             type = Integer
@@ -201,11 +204,12 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.RePlAcEaBlE_fUnC().type, Integer)
         assert isinstance(func.replaceable_func().type, Integer)
 
-        assert "replaceable_func" in reg['_default']
-        assert "REPLACEABLE_FUNC" not in reg['_default']
-        assert "Replaceable_Func" not in reg['_default']
-        assert "replaceable_func" in cs_reg['_default']
-        assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC']
+        in_("replaceable_func", reg['_default'])
+        not_in_("REPLACEABLE_FUNC", reg['_default'])
+        not_in_("Replaceable_Func", reg['_default'])
+        in_("replaceable_func", cs_reg['_default'])
+        eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+            ['REPLACEABLE_FUNC'])
 
         with testing.expect_deprecated():
             class Replaceable_Func(GenericFunction):
@@ -217,12 +221,12 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType)
         assert isinstance(func.replaceable_func().type, NullType)
 
-        assert "replaceable_func" not in reg['_default']
-        assert "REPLACEABLE_FUNC" in reg['_default']
-        assert "Replaceable_Func" in reg['_default']
-        assert "replaceable_func" in cs_reg['_default']
-        assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC',
-                                                          'Replaceable_Func']
+        not_in_("replaceable_func", reg['_default'])
+        in_("REPLACEABLE_FUNC", reg['_default'])
+        in_("Replaceable_Func", reg['_default'])
+        in_("replaceable_func", cs_reg['_default'])
+        eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+            ['REPLACEABLE_FUNC', 'Replaceable_Func'])
 
         with expect_warnings():
             class replaceable_func_override(GenericFunction):
@@ -244,13 +248,12 @@ class DeprecationWarningsTest(fixtures.TestBase):
         assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType)
         assert isinstance(func.replaceable_func().type, String)
 
-        assert "replaceable_func" in reg['_default']
-        assert "REPLACEABLE_FUNC" in reg['_default']
-        assert "Replaceable_Func" in reg['_default']
-        assert "replaceable_func" in cs_reg['_default']
-        assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC',
-                                                          'Replaceable_Func',
-                                                          'replaceable_func']
+        in_("replaceable_func", reg['_default'])
+        in_("REPLACEABLE_FUNC", reg['_default'])
+        in_("Replaceable_Func", reg['_default'])
+        in_("replaceable_func", cs_reg['_default'])
+        eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+            ['REPLACEABLE_FUNC', 'Replaceable_Func', 'replaceable_func'])
 
 
 class DDLListenerDeprecationsTest(fixtures.TestBase):
index 62107ee5bda61093bdbd6abb136d276e02e35747..e4b40892d0129728f0b8d918e08627330891986d 100644 (file)
@@ -57,11 +57,12 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def setup_method(self):
         self._registry = deepcopy(functions._registry)
-        self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg)
+        self._case_sensitive_registry = deepcopy(
+            functions._case_sensitive_registry)
 
     def teardown_method(self):
         functions._registry = self._registry
-        functions._case_sensitive_reg = self._case_sensitive_reg
+        functions._case_sensitive_registry = self._case_sensitive_registry
 
     def test_compile(self):
         for dialect in all_dialects(exclude=("sybase",)):
@@ -95,7 +96,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             )
 
             functions._registry['_default'].pop('fake_func')
-            functions._case_sensitive_reg['_default'].pop('fake_func')
+            functions._case_sensitive_registry['_default'].pop('fake_func')
 
     def test_use_labels(self):
         self.assert_compile(