"""SQL function API, factories, and built-in functions.
"""
-from functools import partial
import warnings
from . import annotation
from .. import util
-_RegistryList = partial(util.defaultdict, list)
-
_registry = util.defaultdict(dict)
-_case_sensitive_reg = util.defaultdict(_RegistryList)
+_case_sensitive_registry = util.defaultdict(
+ lambda: util.defaultdict(dict)
+)
def register_function(identifier, fn, package="_default"):
"""
reg = _registry[package]
+ case_sensitive_reg = _case_sensitive_registry[package]
raw_identifier = identifier
identifier = identifier.lower()
# Check if a function with the same lowercase identifier is registered.
- if identifier in _case_sensitive_reg[package]:
+ if case_sensitive_reg[identifier]:
if (
- raw_identifier not in _case_sensitive_reg[package][identifier]
+ raw_identifier not in case_sensitive_reg[identifier]
):
util.warn_deprecated(
- "GenericFunction(s) '{}' are already registered with "
+ "GenericFunction(s) {} are already registered with "
"different letter cases and might interact with {}.".format(
- _case_sensitive_reg[package][identifier],
+ list(case_sensitive_reg[identifier].keys()),
raw_identifier))
# If a function with the same lowercase identifier is registered,
# then these 2 functions are considered as case-sensitive.
- if len(_case_sensitive_reg[package][identifier]) == 1:
- old_fn = reg[identifier]
- del reg[identifier]
- reg[_case_sensitive_reg[package][identifier][0]] = old_fn
+ if len(case_sensitive_reg[identifier]) == 1:
+ reg.pop(identifier, None)
+ reg.update(case_sensitive_reg[identifier])
+
else:
warnings.warn(
"The GenericFunction '{}' is already registered and "
reg[identifier] = fn
# Add the raw_identifier to the case-sensitive registry
- if raw_identifier not in _case_sensitive_reg[package][identifier]:
- _case_sensitive_reg[package][identifier].append(raw_identifier)
+ if raw_identifier not in case_sensitive_reg[identifier]:
+ case_sensitive_reg[identifier][raw_identifier] = fn
class FunctionElement(Executable, ColumnElement, FromClause):
package = None
if package is not None:
+ reg = _registry[package]
+ case_sensitive_reg = _case_sensitive_registry[package]
if (
- len(_case_sensitive_reg[package].get(fname.lower(), [])) > 1
+ len(case_sensitive_reg.get(fname.lower(), [])) > 1
):
- func = _registry[package].get(fname)
+ func = reg.get(fname)
else:
- func = _registry[package].get(fname.lower())
+ func = reg.get(fname.lower())
if func is not None:
return func(*c, **o)
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import in_
from sqlalchemy.testing import mock
+from sqlalchemy.testing import not_in_
from sqlalchemy.testing.assertions import expect_warnings
def setup_method(self):
self._registry = deepcopy(functions._registry)
- self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg)
+ self._case_sensitive_registry = deepcopy(
+ functions._case_sensitive_registry)
def teardown_method(self):
functions._registry = self._registry
- functions._case_sensitive_reg = self._case_sensitive_reg
+ functions._case_sensitive_registry = self._case_sensitive_registry
def test_ident_preparer_force(self):
preparer = testing.db.dialect.identifier_preparer
def test_case_sensitive(self):
reg = functions._registry
- cs_reg = functions._case_sensitive_reg
+ cs_reg = functions._case_sensitive_registry
class MYFUNC(GenericFunction):
type = DateTime
assert isinstance(func.mYfUnC().type, DateTime)
assert isinstance(func.myfunc().type, DateTime)
- assert "myfunc" in reg['_default']
- assert "MYFUNC" not in reg['_default']
- assert "MyFunc" not in reg['_default']
- assert "myfunc" in cs_reg['_default']
- assert cs_reg['_default']['myfunc'] == ['MYFUNC']
+ in_("myfunc", reg['_default'])
+ not_in_("MYFUNC", reg['_default'])
+ not_in_("MyFunc", reg['_default'])
+ in_("myfunc", cs_reg['_default'])
+ eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC'])
with testing.expect_deprecated():
class MyFunc(GenericFunction):
with pytest.raises(AssertionError):
assert isinstance(func.myfunc().type, Integer)
- assert "myfunc" not in reg['_default']
- assert "MYFUNC" in reg['_default']
- assert "MyFunc" in reg['_default']
- assert "myfunc" in cs_reg['_default']
- assert cs_reg['_default']['myfunc'] == ['MYFUNC', 'MyFunc']
+ not_in_("myfunc", reg['_default'])
+ in_("MYFUNC", reg['_default'])
+ in_("MyFunc", reg['_default'])
+ in_("myfunc", cs_reg['_default'])
+ eq_(list(cs_reg['_default']['myfunc'].keys()), ['MYFUNC', 'MyFunc'])
def test_replace_function_case_sensitive(self):
reg = functions._registry
- cs_reg = functions._case_sensitive_reg
+ cs_reg = functions._case_sensitive_registry
class replaceable_func(GenericFunction):
type = Integer
assert isinstance(func.RePlAcEaBlE_fUnC().type, Integer)
assert isinstance(func.replaceable_func().type, Integer)
- assert "replaceable_func" in reg['_default']
- assert "REPLACEABLE_FUNC" not in reg['_default']
- assert "Replaceable_Func" not in reg['_default']
- assert "replaceable_func" in cs_reg['_default']
- assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC']
+ in_("replaceable_func", reg['_default'])
+ not_in_("REPLACEABLE_FUNC", reg['_default'])
+ not_in_("Replaceable_Func", reg['_default'])
+ in_("replaceable_func", cs_reg['_default'])
+ eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+ ['REPLACEABLE_FUNC'])
with testing.expect_deprecated():
class Replaceable_Func(GenericFunction):
assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType)
assert isinstance(func.replaceable_func().type, NullType)
- assert "replaceable_func" not in reg['_default']
- assert "REPLACEABLE_FUNC" in reg['_default']
- assert "Replaceable_Func" in reg['_default']
- assert "replaceable_func" in cs_reg['_default']
- assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC',
- 'Replaceable_Func']
+ not_in_("replaceable_func", reg['_default'])
+ in_("REPLACEABLE_FUNC", reg['_default'])
+ in_("Replaceable_Func", reg['_default'])
+ in_("replaceable_func", cs_reg['_default'])
+ eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+ ['REPLACEABLE_FUNC', 'Replaceable_Func'])
with expect_warnings():
class replaceable_func_override(GenericFunction):
assert isinstance(func.RePlAcEaBlE_fUnC().type, NullType)
assert isinstance(func.replaceable_func().type, String)
- assert "replaceable_func" in reg['_default']
- assert "REPLACEABLE_FUNC" in reg['_default']
- assert "Replaceable_Func" in reg['_default']
- assert "replaceable_func" in cs_reg['_default']
- assert cs_reg['_default']['replaceable_func'] == ['REPLACEABLE_FUNC',
- 'Replaceable_Func',
- 'replaceable_func']
+ in_("replaceable_func", reg['_default'])
+ in_("REPLACEABLE_FUNC", reg['_default'])
+ in_("Replaceable_Func", reg['_default'])
+ in_("replaceable_func", cs_reg['_default'])
+ eq_(list(cs_reg['_default']['replaceable_func'].keys()),
+ ['REPLACEABLE_FUNC', 'Replaceable_Func', 'replaceable_func'])
class DDLListenerDeprecationsTest(fixtures.TestBase):