_registry = util.defaultdict(dict)
+_case_insensitive_functions = util.defaultdict(dict)
-def register_function(identifier, fn, package="_default"):
+def register_function(identifier, fn, package="_default",
+ case_sensitive=True):
"""Associate a callable with a particular func. name.
This is normally called by _GenericMeta, but is also
"""
reg = _registry[package]
reg[identifier] = fn
+ if not case_sensitive:
+ _case_insensitive_functions[package][identifier.lower()] = identifier
+ elif identifier.lower() in _case_insensitive_functions[package]:
+ del _case_insensitive_functions[package][identifier.lower()]
class FunctionElement(Executable, ColumnElement, FromClause):
package = None
if package is not None:
- func = _registry[package].get(fname)
+ func = _registry[package].get(fname) or _registry[package].get(
+ _case_insensitive_functions[package].get(fname.lower()))
if func is not None:
return func(*c, **o)
if annotation.Annotated not in cls.__mro__:
cls.name = name = clsdict.get("name", clsname)
cls.identifier = identifier = clsdict.get("identifier", name)
+ cls.case_sensitive = case_sensitive = clsdict.get(
+ "case_sensitive", True)
package = clsdict.pop("package", "_default")
# legacy
if "__return_type__" in clsdict:
cls.type = clsdict["__return_type__"]
- register_function(identifier, cls, package)
+ register_function(identifier, cls, package, case_sensitive)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
import datetime
import decimal
+import pytest
from sqlalchemy import ARRAY
from sqlalchemy import bindparam
def tear_down(self):
functions._registry.clear()
+ functions._case_insensitive_functions.clear()
def test_compile(self):
for dialect in all_dialects(exclude=("sybase",)):
class myfunc(GenericFunction):
__return_type__ = DateTime
+ with pytest.raises(AssertionError):
+ assert isinstance(func.MyFunc().type, DateTime)
+ with pytest.raises(AssertionError):
+ assert isinstance(func.mYfUnC().type, DateTime)
+ assert isinstance(func.myfunc().type, DateTime)
+
+ def test_custom_legacy_type_case_insensitive(self):
+ # in case someone was using this system
+ class MyFunc(GenericFunction):
+ __return_type__ = DateTime
+ case_sensitive = False
+
+ assert isinstance(func.MyFunc().type, DateTime)
+ assert isinstance(func.mYfUnC().type, DateTime)
assert isinstance(func.myfunc().type, DateTime)
def test_custom_w_custom_name(self):
type = Integer
identifier = "buf3"
+ class GeoBufferFour(GenericFunction):
+ type = Integer
+ name = "BufferFour"
+ identifier = "Buf4"
+ case_sensitive = False
+
self.assert_compile(func.geo.buf1(), "BufferOne()")
self.assert_compile(func.buf2(), "BufferTwo()")
self.assert_compile(func.buf3(), "BufferThree()")
+ self.assert_compile(func.Buf4(), "BufferFour()")
+ self.assert_compile(func.BuF4(), "BufferFour()")
+ self.assert_compile(func.bUf4(), "BufferFour()")
+ self.assert_compile(func.bUf4_(), "BufferFour()")
+ self.assert_compile(func.buf4(), "BufferFour()")
+
+ class geobufferfour(GenericFunction):
+ type = Integer
+ name = "BufferFour"
+ identifier = "Buf4"
+ case_sensitive = True
+
+ self.assert_compile(func.Buf4(), "BufferFour()")
+ with pytest.raises(AssertionError):
+ self.assert_compile(func.BuF4(), "BufferFour()")
+ with pytest.raises(AssertionError):
+ self.assert_compile(func.buf4(), "BufferFour()")
def test_custom_args(self):
class myfunc(GenericFunction):