"""SQL function API, factories, and built-in functions.
"""
+from functools import partial
+import warnings
+
from . import annotation
from . import operators
from . import schema
from .selectable import FromClause
from .selectable import Select
from .visitors import VisitableType
+from .. import exc as sa_exc
from .. import util
+_RegistryList = partial(util.defaultdict, list)
+
_registry = util.defaultdict(dict)
-_case_insensitive_functions = util.defaultdict(dict)
+_case_sensitive_reg = util.defaultdict(_RegistryList)
-def register_function(identifier, fn, package="_default",
- case_sensitive=True):
+def register_function(identifier, fn, package="_default"):
"""Associate a callable with a particular func. name.
This is normally called by _GenericMeta, but is also
"""
reg = _registry[package]
- reg[identifier] = fn
- if not case_sensitive:
- _case_insensitive_functions[package][identifier.lower()] = identifier
- elif identifier.lower() in _case_insensitive_functions[package]:
- del _case_insensitive_functions[package][identifier.lower()]
+ raw_identifier = identifier
+ identifier = identifier.lower()
+
+ # Check if a function with the same lowercase identifier is registered.
+ if identifier in _case_sensitive_reg[package]:
+ if (
+ raw_identifier not in _case_sensitive_reg[package][identifier]
+ ):
+ warnings.warn(
+ "GenericFunction(s) '{}' are already registered with "
+ "different letter cases and might interact with {}.".format(
+ _case_sensitive_reg[package][identifier],
+ raw_identifier),
+ sa_exc.SADeprecationWarning)
+ else:
+ warnings.warn(
+ "The GenericFunction '{}' is already registered and "
+ "is going to be overriden.".format(identifier))
+
+ # If a function with the same lowercase identifier is registered, then
+ # these 2 functions are considered as case-sensitive.
+ if len(_case_sensitive_reg[package][identifier]) == 1:
+ old_fn = reg[identifier]
+ del reg[identifier]
+ reg[_case_sensitive_reg[package][identifier][0]] = old_fn
+ reg[raw_identifier] = fn
+ else:
+ reg[identifier] = fn
+
+ if raw_identifier not in _case_sensitive_reg[package][identifier]:
+ _case_sensitive_reg[package][identifier].append(raw_identifier)
class FunctionElement(Executable, ColumnElement, FromClause):
package = None
if package is not None:
- func = _registry[package].get(fname) or _registry[package].get(
- _case_insensitive_functions[package].get(fname.lower()))
+ if (
+ len(_case_sensitive_reg[package].get(fname.lower(), [])) > 1
+ ):
+ func = _registry[package].get(fname)
+ else:
+ func = _registry[package].get(fname.lower())
if func is not None:
return func(*c, **o)
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
cls.identifier = identifier = clsdict.get("identifier", name)
- cls.case_sensitive = case_sensitive = clsdict.get(
- "case_sensitive", True)
package = clsdict.pop("package", "_default")
# legacy
if "__return_type__" in clsdict:
cls.type = clsdict["__return_type__"]
- register_function(identifier, cls, package, case_sensitive)
+ register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
from sqlalchemy.sql.compiler import BIND_TEMPLATES
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
def tear_down(self):
functions._registry.clear()
- functions._case_insensitive_functions.clear()
+ functions._case_sensitive_functions.clear()
def test_compile(self):
for dialect in all_dialects(exclude=("sybase",)):
def test_uppercase(self):
# for now, we need to keep case insensitivity
- self.assert_compile(func.NOW(), "NOW()")
+ self.assert_compile(func.UNREGISTERED_FN(), "UNREGISTERED_FN()")
def test_uppercase_packages(self):
# for now, we need to keep case insensitivity
class myfunc(GenericFunction):
__return_type__ = DateTime
- with pytest.raises(AssertionError):
- assert isinstance(func.MyFunc().type, DateTime)
- with pytest.raises(AssertionError):
- assert isinstance(func.mYfUnC().type, DateTime)
assert isinstance(func.myfunc().type, DateTime)
- def test_custom_legacy_type_case_insensitive(self):
- # in case someone was using this system
- class MyFunc(GenericFunction):
+ def test_replace_function(self):
+
+ class replacable_func(GenericFunction):
+ __return_type__ = Integer
+ identifier = 'replacable_func'
+
+ assert isinstance(func.Replacable_Func().type, Integer)
+ assert isinstance(func.RePlAcaBlE_fUnC().type, Integer)
+ assert isinstance(func.replacable_func().type, Integer)
+
+ with testing.expect_deprecated():
+ class Replacable_Func(GenericFunction):
+ __return_type__ = DateTime
+ identifier = 'Replacable_Func'
+
+ assert isinstance(func.Replacable_Func().type, DateTime)
+ assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+ assert isinstance(func.replacable_func().type, Integer)
+
+ class replacable_func_override(GenericFunction):
__return_type__ = DateTime
- case_sensitive = False
+ identifier = 'replacable_func'
- assert isinstance(func.MyFunc().type, DateTime)
- assert isinstance(func.mYfUnC().type, DateTime)
- assert isinstance(func.myfunc().type, DateTime)
+ class Replacable_Func_override(GenericFunction):
+ __return_type__ = Integer
+ identifier = 'Replacable_Func'
+
+ assert isinstance(func.Replacable_Func().type, Integer)
+ assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+ assert isinstance(func.replacable_func().type, DateTime)
+
+ def test_custom_legacy_case_sensitive(self):
+ # in case someone was using this system
+ with testing.expect_deprecated():
+ class MyFunc(GenericFunction):
+ __return_type__ = Integer
+
+ assert isinstance(func.MyFunc().type, Integer)
+ with pytest.raises(AssertionError):
+ assert isinstance(func.mYfUnC().type, Integer)
+ with pytest.raises(AssertionError):
+ assert isinstance(func.myfunc().type, Integer)
def test_custom_w_custom_name(self):
class myfunc(GenericFunction):
type = Integer
name = "BufferFour"
identifier = "Buf4"
- case_sensitive = False
self.assert_compile(func.geo.buf1(), "BufferOne()")
self.assert_compile(func.buf2(), "BufferTwo()")
self.assert_compile(func.bUf4_(), "BufferFour()")
self.assert_compile(func.buf4(), "BufferFour()")
- class geobufferfour(GenericFunction):
- type = Integer
- name = "BufferFour"
- identifier = "Buf4"
- case_sensitive = True
+ with testing.expect_deprecated():
+ class geobufferfour_lowercase(GenericFunction):
+ type = Integer
+ name = "BufferFour_lowercase"
+ identifier = "buf4"
+
+ with testing.expect_deprecated():
+ class geobufferfour_random_case(GenericFunction):
+ type = Integer
+ name = "BuFferFouR"
+ identifier = "BuF4"
self.assert_compile(func.Buf4(), "BufferFour()")
- with pytest.raises(AssertionError):
- self.assert_compile(func.BuF4(), "BufferFour()")
- with pytest.raises(AssertionError):
- self.assert_compile(func.buf4(), "BufferFour()")
+ self.assert_compile(func.BuF4(), "BuFferFouR()")
+ self.assert_compile(func.buf4(), "BufferFour_lowercase()")
def test_custom_args(self):
class myfunc(GenericFunction):