]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add case insensitivity feature to GenericFunction.
authorAdrien Berchet <adrien.berchet@gmail.com>
Sun, 24 Mar 2019 15:14:45 +0000 (16:14 +0100)
committerAdrien Berchet <adrien.berchet@gmail.com>
Mon, 25 Mar 2019 12:07:42 +0000 (13:07 +0100)
lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

index fcc843d9195343c02d1c45fe103e0a88132b0650..aadbb73d813e236a46a6d4d08a8a54fe5d18b70e 100644 (file)
@@ -37,9 +37,11 @@ from .. import util
 
 
 _registry = util.defaultdict(dict)
+_case_insensitive_functions = util.defaultdict(dict)
 
 
-def register_function(identifier, fn, package="_default"):
+def register_function(identifier, fn, package="_default",
+                      case_sensitive=True):
     """Associate a callable with a particular func. name.
 
     This is normally called by _GenericMeta, but is also
@@ -50,6 +52,10 @@ def register_function(identifier, fn, package="_default"):
     """
     reg = _registry[package]
     reg[identifier] = fn
+    if not case_sensitive:
+        _case_insensitive_functions[package][identifier.lower()] = identifier
+    elif identifier.lower() in _case_insensitive_functions[package]:
+        del _case_insensitive_functions[package][identifier.lower()]
 
 
 class FunctionElement(Executable, ColumnElement, FromClause):
@@ -440,7 +446,8 @@ class _FunctionGenerator(object):
             package = None
 
         if package is not None:
-            func = _registry[package].get(fname)
+            func = _registry[package].get(fname) or _registry[package].get(
+                _case_insensitive_functions[package].get(fname.lower()))
             if func is not None:
                 return func(*c, **o)
 
@@ -570,11 +577,13 @@ class _GenericMeta(VisitableType):
         if annotation.Annotated not in cls.__mro__:
             cls.name = name = clsdict.get("name", clsname)
             cls.identifier = identifier = clsdict.get("identifier", name)
+            cls.case_sensitive = case_sensitive = clsdict.get(
+                "case_sensitive", True)
             package = clsdict.pop("package", "_default")
             # legacy
             if "__return_type__" in clsdict:
                 cls.type = clsdict["__return_type__"]
-            register_function(identifier, cls, package)
+            register_function(identifier, cls, package, case_sensitive)
         super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
 
 
index 5fb4bc2e4a6da9b3f02a717adbdbc377f88d9bd1..a2ed9f0486e1a861fc10dae99cccbf00e07912f5 100644 (file)
@@ -1,5 +1,6 @@
 import datetime
 import decimal
+import pytest
 
 from sqlalchemy import ARRAY
 from sqlalchemy import bindparam
@@ -55,6 +56,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def tear_down(self):
         functions._registry.clear()
+        functions._case_insensitive_functions.clear()
 
     def test_compile(self):
         for dialect in all_dialects(exclude=("sybase",)):
@@ -217,6 +219,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         class myfunc(GenericFunction):
             __return_type__ = DateTime
 
+        with pytest.raises(AssertionError):
+            assert isinstance(func.MyFunc().type, DateTime)
+        with pytest.raises(AssertionError):
+            assert isinstance(func.mYfUnC().type, DateTime)
+        assert isinstance(func.myfunc().type, DateTime)
+
+    def test_custom_legacy_type_case_insensitive(self):
+        # in case someone was using this system
+        class MyFunc(GenericFunction):
+            __return_type__ = DateTime
+            case_sensitive = False
+
+        assert isinstance(func.MyFunc().type, DateTime)
+        assert isinstance(func.mYfUnC().type, DateTime)
         assert isinstance(func.myfunc().type, DateTime)
 
     def test_custom_w_custom_name(self):
@@ -267,9 +283,32 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             type = Integer
             identifier = "buf3"
 
+        class GeoBufferFour(GenericFunction):
+            type = Integer
+            name = "BufferFour"
+            identifier = "Buf4"
+            case_sensitive = False
+
         self.assert_compile(func.geo.buf1(), "BufferOne()")
         self.assert_compile(func.buf2(), "BufferTwo()")
         self.assert_compile(func.buf3(), "BufferThree()")
+        self.assert_compile(func.Buf4(), "BufferFour()")
+        self.assert_compile(func.BuF4(), "BufferFour()")
+        self.assert_compile(func.bUf4(), "BufferFour()")
+        self.assert_compile(func.bUf4_(), "BufferFour()")
+        self.assert_compile(func.buf4(), "BufferFour()")
+
+        class geobufferfour(GenericFunction):
+            type = Integer
+            name = "BufferFour"
+            identifier = "Buf4"
+            case_sensitive = True
+
+        self.assert_compile(func.Buf4(), "BufferFour()")
+        with pytest.raises(AssertionError):
+            self.assert_compile(func.BuF4(), "BufferFour()")
+        with pytest.raises(AssertionError):
+            self.assert_compile(func.buf4(), "BufferFour()")
 
     def test_custom_args(self):
         class myfunc(GenericFunction):