From: Adrien Berchet Date: Sun, 24 Mar 2019 15:14:45 +0000 (+0100) Subject: Add case insensitivity feature to GenericFunction. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=30ac79e81f07f1f70f2adb64c9b791d3c766c517;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add case insensitivity feature to GenericFunction. --- diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fcc843d919..aadbb73d81 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -37,9 +37,11 @@ from .. import util _registry = util.defaultdict(dict) +_case_insensitive_functions = util.defaultdict(dict) -def register_function(identifier, fn, package="_default"): +def register_function(identifier, fn, package="_default", + case_sensitive=True): """Associate a callable with a particular func. name. This is normally called by _GenericMeta, but is also @@ -50,6 +52,10 @@ def register_function(identifier, fn, package="_default"): """ reg = _registry[package] reg[identifier] = fn + if not case_sensitive: + _case_insensitive_functions[package][identifier.lower()] = identifier + elif identifier.lower() in _case_insensitive_functions[package]: + del _case_insensitive_functions[package][identifier.lower()] class FunctionElement(Executable, ColumnElement, FromClause): @@ -440,7 +446,8 @@ class _FunctionGenerator(object): package = None if package is not None: - func = _registry[package].get(fname) + func = _registry[package].get(fname) or _registry[package].get( + _case_insensitive_functions[package].get(fname.lower())) if func is not None: return func(*c, **o) @@ -570,11 +577,13 @@ class _GenericMeta(VisitableType): if annotation.Annotated not in cls.__mro__: cls.name = name = clsdict.get("name", clsname) cls.identifier = identifier = clsdict.get("identifier", name) + cls.case_sensitive = case_sensitive = clsdict.get( + "case_sensitive", True) package = clsdict.pop("package", "_default") # legacy if "__return_type__" in clsdict: cls.type = clsdict["__return_type__"] - register_function(identifier, cls, package) + register_function(identifier, cls, package, case_sensitive) super(_GenericMeta, cls).__init__(clsname, bases, clsdict) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 5fb4bc2e4a..a2ed9f0486 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -1,5 +1,6 @@ import datetime import decimal +import pytest from sqlalchemy import ARRAY from sqlalchemy import bindparam @@ -55,6 +56,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def tear_down(self): functions._registry.clear() + functions._case_insensitive_functions.clear() def test_compile(self): for dialect in all_dialects(exclude=("sybase",)): @@ -217,6 +219,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): class myfunc(GenericFunction): __return_type__ = DateTime + with pytest.raises(AssertionError): + assert isinstance(func.MyFunc().type, DateTime) + with pytest.raises(AssertionError): + assert isinstance(func.mYfUnC().type, DateTime) + assert isinstance(func.myfunc().type, DateTime) + + def test_custom_legacy_type_case_insensitive(self): + # in case someone was using this system + class MyFunc(GenericFunction): + __return_type__ = DateTime + case_sensitive = False + + assert isinstance(func.MyFunc().type, DateTime) + assert isinstance(func.mYfUnC().type, DateTime) assert isinstance(func.myfunc().type, DateTime) def test_custom_w_custom_name(self): @@ -267,9 +283,32 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): type = Integer identifier = "buf3" + class GeoBufferFour(GenericFunction): + type = Integer + name = "BufferFour" + identifier = "Buf4" + case_sensitive = False + self.assert_compile(func.geo.buf1(), "BufferOne()") self.assert_compile(func.buf2(), "BufferTwo()") self.assert_compile(func.buf3(), "BufferThree()") + self.assert_compile(func.Buf4(), "BufferFour()") + self.assert_compile(func.BuF4(), "BufferFour()") + self.assert_compile(func.bUf4(), "BufferFour()") + self.assert_compile(func.bUf4_(), "BufferFour()") + self.assert_compile(func.buf4(), "BufferFour()") + + class geobufferfour(GenericFunction): + type = Integer + name = "BufferFour" + identifier = "Buf4" + case_sensitive = True + + self.assert_compile(func.Buf4(), "BufferFour()") + with pytest.raises(AssertionError): + self.assert_compile(func.BuF4(), "BufferFour()") + with pytest.raises(AssertionError): + self.assert_compile(func.buf4(), "BufferFour()") def test_custom_args(self): class myfunc(GenericFunction):