From: Adrien Berchet Date: Sun, 7 Apr 2019 20:25:40 +0000 (+0200) Subject: Use dict instead of list and use more sqlalchemy.testing functions. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=79bd3530ea08f679a7fde27973e98edd57b81405;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Use dict instead of list and use more sqlalchemy.testing functions. --- diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index ed3fbaf2f4..dc97a97f24 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -8,7 +8,6 @@ """SQL function API, factories, and built-in functions. """ -from functools import partial import warnings from . import annotation @@ -40,10 +39,10 @@ from .. import exc as sa_exc from .. import util -_RegistryList = partial(util.defaultdict, list) - _registry = util.defaultdict(dict) -_case_sensitive_reg = util.defaultdict(_RegistryList) +_case_sensitive_registry = util.defaultdict( + lambda: util.defaultdict(dict) +) def register_function(identifier, fn, package="_default"): @@ -56,25 +55,26 @@ def register_function(identifier, fn, package="_default"): """ reg = _registry[package] + case_sensitive_reg = _case_sensitive_registry[package] raw_identifier = identifier identifier = identifier.lower() # Check if a function with the same lowercase identifier is registered. - if identifier in _case_sensitive_reg[package]: + if case_sensitive_reg[identifier]: if ( - raw_identifier not in _case_sensitive_reg[package][identifier] + raw_identifier not in case_sensitive_reg[identifier] ): util.warn_deprecated( - "GenericFunction(s) '{}' are already registered with " + "GenericFunction(s) {} are already registered with " "different letter cases and might interact with {}.".format( - _case_sensitive_reg[package][identifier], + list(case_sensitive_reg[identifier].keys()), raw_identifier)) # If a function with the same lowercase identifier is registered, # then these 2 functions are considered as case-sensitive. - if len(_case_sensitive_reg[package][identifier]) == 1: - old_fn = reg[identifier] - del reg[identifier] - reg[_case_sensitive_reg[package][identifier][0]] = old_fn + if len(case_sensitive_reg[identifier]) == 1: + reg.pop(identifier, None) + reg.update(case_sensitive_reg[identifier]) + else: warnings.warn( "The GenericFunction '{}' is already registered and " @@ -86,8 +86,8 @@ def register_function(identifier, fn, package="_default"): reg[identifier] = fn # Add the raw_identifier to the case-sensitive registry - if raw_identifier not in _case_sensitive_reg[package][identifier]: - _case_sensitive_reg[package][identifier].append(raw_identifier) + if raw_identifier not in case_sensitive_reg[identifier]: + case_sensitive_reg[identifier][raw_identifier] = fn class FunctionElement(Executable, ColumnElement, FromClause): @@ -478,12 +478,14 @@ class _FunctionGenerator(object): package = None if package is not None: + reg = _registry[package] + case_sensitive_reg = _case_sensitive_registry[package] if ( - len(_case_sensitive_reg[package].get(fname.lower(), [])) > 1 + len(case_sensitive_reg.get(fname.lower(), [])) > 1 ): - func = _registry[package].get(fname) + func = reg.get(fname) else: - func = _registry[package].get(fname.lower()) + func = reg.get(fname.lower()) if func is not None: return func(*c, **o) diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index 27cd8dcc53..4b55332b49 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -33,7 +33,9 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import in_ from sqlalchemy.testing import mock +from sqlalchemy.testing import not_in_ from sqlalchemy.testing.assertions import expect_warnings @@ -42,11 +44,12 @@ class DeprecationWarningsTest(fixtures.TestBase): def setup_method(self): self._registry = deepcopy(functions._registry) - self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg) + self._case_sensitive_registry = deepcopy( + functions._case_sensitive_registry) def teardown_method(self): functions._registry = self._registry - functions._case_sensitive_reg = self._case_sensitive_reg + functions._case_sensitive_registry = self._case_sensitive_registry def test_ident_preparer_force(self): preparer = testing.db.dialect.identifier_preparer @@ -155,7 +158,7 @@ class DeprecationWarningsTest(fixtures.TestBase): def test_case_sensitive(self): reg = functions._registry - cs_reg = functions._case_sensitive_reg + cs_reg = functions._case_sensitive_registry class MYFUNC(GenericFunction): type = DateTime @@ -165,11 +168,11 @@ class DeprecationWarningsTest(fixtures.TestBase): assert isinstance(func.mYfUnC().type, DateTime) assert isinstance(func.myfunc().type, DateTime) - assert "myfunc" in reg['_default'] - assert "MYFUNC" not in reg['_default'] - assert "MyFunc" not in reg['_default'] - assert "myfunc" in cs_reg['_default'] - assert cs_reg['_default']['myfunc'] == ['MYFUNC'] + in_("myfunc", reg['_default']) + not_in_("MYFUNC", reg['_default']) + not_in_("MyFunc", reg['_default']) + in_("myfunc", cs_reg['_default']) + eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC']) with testing.expect_deprecated(): class MyFunc(GenericFunction): @@ -182,15 +185,15 @@ class DeprecationWarningsTest(fixtures.TestBase): with pytest.raises(AssertionError): assert isinstance(func.myfunc().type, Integer) - assert "myfunc" not in reg['_default'] - assert "MYFUNC" in reg['_default'] - assert "MyFunc" in reg['_default'] - assert "myfunc" in cs_reg['_default'] - assert cs_reg['_default']['myfunc'] == ['MYFUNC', 'MyFunc'] + not_in_("myfunc", reg['_default']) + in_("MYFUNC", reg['_default']) + in_("MyFunc", reg['_default']) + in_("myfunc", cs_reg['_default']) + eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC', 'MyFunc']) def test_replace_function_case_sensitive(self): reg = functions._registry - cs_reg = functions._case_sensitive_reg + cs_reg = functions._case_sensitive_registry class replaceable_func(GenericFunction): type = Integer @@ -201,11 +204,12 @@ class DeprecationWarningsTest(fixtures.TestBase): assert isinstance(func.RePlAcEaBlE_fUnC().type, Integer) assert isinstance(func.replaceable_func().type, Integer) - assert "replaceable_func" in reg['_default'] - assert "REPLACEABLE_FUNC" not in reg['_default'] - assert "Replaceable_Func" not in reg['_default'] - assert "replaceable_func" in cs_reg['_default'] - assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC'] + in_("replaceable_func", reg['_default']) + not_in_("REPLACEABLE_FUNC", reg['_default']) + not_in_("Replaceable_Func", reg['_default']) + in_("replaceable_func", cs_reg['_default']) + eq_(list(cs_reg['_default']['replaceable_func'].keys()), + ['REPLACEABLE_FUNC']) with testing.expect_deprecated(): class Replaceable_Func(GenericFunction): @@ -217,12 +221,12 @@ class DeprecationWarningsTest(fixtures.TestBase): assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType) assert isinstance(func.replaceable_func().type, NullType) - assert "replaceable_func" not in reg['_default'] - assert "REPLACEABLE_FUNC" in reg['_default'] - assert "Replaceable_Func" in reg['_default'] - assert "replaceable_func" in cs_reg['_default'] - assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC', - 'Replaceable_Func'] + not_in_("replaceable_func", reg['_default']) + in_("REPLACEABLE_FUNC", reg['_default']) + in_("Replaceable_Func", reg['_default']) + in_("replaceable_func", cs_reg['_default']) + eq_(list(cs_reg['_default']['replaceable_func'].keys()), + ['REPLACEABLE_FUNC', 'Replaceable_Func']) with expect_warnings(): class replaceable_func_override(GenericFunction): @@ -244,13 +248,12 @@ class DeprecationWarningsTest(fixtures.TestBase): assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType) assert isinstance(func.replaceable_func().type, String) - assert "replaceable_func" in reg['_default'] - assert "REPLACEABLE_FUNC" in reg['_default'] - assert "Replaceable_Func" in reg['_default'] - assert "replaceable_func" in cs_reg['_default'] - assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC', - 'Replaceable_Func', - 'replaceable_func'] + in_("replaceable_func", reg['_default']) + in_("REPLACEABLE_FUNC", reg['_default']) + in_("Replaceable_Func", reg['_default']) + in_("replaceable_func", cs_reg['_default']) + eq_(list(cs_reg['_default']['replaceable_func'].keys()), + ['REPLACEABLE_FUNC', 'Replaceable_Func', 'replaceable_func']) class DDLListenerDeprecationsTest(fixtures.TestBase): diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 62107ee5bd..e4b40892d0 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -57,11 +57,12 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def setup_method(self): self._registry = deepcopy(functions._registry) - self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg) + self._case_sensitive_registry = deepcopy( + functions._case_sensitive_registry) def teardown_method(self): functions._registry = self._registry - functions._case_sensitive_reg = self._case_sensitive_reg + functions._case_sensitive_registry = self._case_sensitive_registry def test_compile(self): for dialect in all_dialects(exclude=("sybase",)): @@ -95,7 +96,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) functions._registry['_default'].pop('fake_func') - functions._case_sensitive_reg['_default'].pop('fake_func') + functions._case_sensitive_registry['_default'].pop('fake_func') def test_use_labels(self): self.assert_compile(