]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use lower case by default.
authorAdrien Berchet <adrien.berchet@gmail.com>
Wed, 27 Mar 2019 19:03:14 +0000 (20:03 +0100)
committerAdrien Berchet <adrien.berchet@gmail.com>
Wed, 27 Mar 2019 19:30:08 +0000 (20:30 +0100)
lib/sqlalchemy/sql/functions.py
test/sql/test_functions.py

index aadbb73d813e236a46a6d4d08a8a54fe5d18b70e..c914b3dd630504d6811ae4c61834dc82ae6a78ce 100644 (file)
@@ -8,6 +8,9 @@
 """SQL function API, factories, and built-in functions.
 
 """
+from functools import partial
+import warnings
+
 from . import annotation
 from . import operators
 from . import schema
@@ -33,15 +36,17 @@ from .selectable import Alias
 from .selectable import FromClause
 from .selectable import Select
 from .visitors import VisitableType
+from .. import exc as sa_exc
 from .. import util
 
 
+_RegistryList = partial(util.defaultdict, list)
+
 _registry = util.defaultdict(dict)
-_case_insensitive_functions = util.defaultdict(dict)
+_case_sensitive_reg = util.defaultdict(_RegistryList)
 
 
-def register_function(identifier, fn, package="_default",
-                      case_sensitive=True):
+def register_function(identifier, fn, package="_default"):
     """Associate a callable with a particular func. name.
 
     This is normally called by _GenericMeta, but is also
@@ -51,11 +56,37 @@ def register_function(identifier, fn, package="_default",
 
     """
     reg = _registry[package]
-    reg[identifier] = fn
-    if not case_sensitive:
-        _case_insensitive_functions[package][identifier.lower()] = identifier
-    elif identifier.lower() in _case_insensitive_functions[package]:
-        del _case_insensitive_functions[package][identifier.lower()]
+    raw_identifier = identifier
+    identifier = identifier.lower()
+
+    # Check if a function with the same lowercase identifier is registered.
+    if identifier in _case_sensitive_reg[package]:
+        if (
+            raw_identifier not in _case_sensitive_reg[package][identifier]
+        ):
+            warnings.warn(
+                "GenericFunction(s) '{}' are already registered with "
+                "different letter cases and might interact with {}.".format(
+                    _case_sensitive_reg[package][identifier],
+                    raw_identifier),
+                sa_exc.SADeprecationWarning)
+        else:
+                warnings.warn(
+                    "The GenericFunction '{}' is already registered and "
+                    "is going to be overriden.".format(identifier))
+
+        # If a function with the same lowercase identifier is registered, then
+        # these 2 functions are considered as case-sensitive.
+        if len(_case_sensitive_reg[package][identifier]) == 1:
+            old_fn = reg[identifier]
+            del reg[identifier]
+            reg[_case_sensitive_reg[package][identifier][0]] = old_fn
+        reg[raw_identifier] = fn
+    else:
+        reg[identifier] = fn
+
+    if raw_identifier not in _case_sensitive_reg[package][identifier]:
+        _case_sensitive_reg[package][identifier].append(raw_identifier)
 
 
 class FunctionElement(Executable, ColumnElement, FromClause):
@@ -446,8 +477,12 @@ class _FunctionGenerator(object):
             package = None
 
         if package is not None:
-            func = _registry[package].get(fname) or _registry[package].get(
-                _case_insensitive_functions[package].get(fname.lower()))
+            if (
+                len(_case_sensitive_reg[package].get(fname.lower(), [])) > 1
+            ):
+                func = _registry[package].get(fname)
+            else:
+                func = _registry[package].get(fname.lower())
             if func is not None:
                 return func(*c, **o)
 
@@ -577,13 +612,11 @@ class _GenericMeta(VisitableType):
         if annotation.Annotated not in cls.__mro__:
             cls.name = name = clsdict.get("name", clsname)
             cls.identifier = identifier = clsdict.get("identifier", name)
-            cls.case_sensitive = case_sensitive = clsdict.get(
-                "case_sensitive", True)
             package = clsdict.pop("package", "_default")
             # legacy
             if "__return_type__" in clsdict:
                 cls.type = clsdict["__return_type__"]
-            register_function(identifier, cls, package, case_sensitive)
+            register_function(identifier, cls, package)
         super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
 
 
index a2ed9f0486e1a861fc10dae99cccbf00e07912f5..0518ee4e49b56404648d719762b8be5149e54801 100644 (file)
@@ -33,6 +33,7 @@ from sqlalchemy.sql import table
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
 from sqlalchemy.sql.functions import FunctionElement
 from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -56,7 +57,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def tear_down(self):
         functions._registry.clear()
-        functions._case_insensitive_functions.clear()
+        functions._case_sensitive_functions.clear()
 
     def test_compile(self):
         for dialect in all_dialects(exclude=("sybase",)):
@@ -102,7 +103,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     def test_uppercase(self):
         # for now, we need to keep case insensitivity
-        self.assert_compile(func.NOW(), "NOW()")
+        self.assert_compile(func.UNREGISTERED_FN(), "UNREGISTERED_FN()")
 
     def test_uppercase_packages(self):
         # for now, we need to keep case insensitivity
@@ -219,21 +220,50 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         class myfunc(GenericFunction):
             __return_type__ = DateTime
 
-        with pytest.raises(AssertionError):
-            assert isinstance(func.MyFunc().type, DateTime)
-        with pytest.raises(AssertionError):
-            assert isinstance(func.mYfUnC().type, DateTime)
         assert isinstance(func.myfunc().type, DateTime)
 
-    def test_custom_legacy_type_case_insensitive(self):
-        # in case someone was using this system
-        class MyFunc(GenericFunction):
+    def test_replace_function(self):
+
+        class replacable_func(GenericFunction):
+            __return_type__ = Integer
+            identifier = 'replacable_func'
+
+        assert isinstance(func.Replacable_Func().type, Integer)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, Integer)
+        assert isinstance(func.replacable_func().type, Integer)
+
+        with testing.expect_deprecated():
+            class Replacable_Func(GenericFunction):
+                __return_type__ = DateTime
+                identifier = 'Replacable_Func'
+
+        assert isinstance(func.Replacable_Func().type, DateTime)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+        assert isinstance(func.replacable_func().type, Integer)
+
+        class replacable_func_override(GenericFunction):
             __return_type__ = DateTime
-            case_sensitive = False
+            identifier = 'replacable_func'
 
-        assert isinstance(func.MyFunc().type, DateTime)
-        assert isinstance(func.mYfUnC().type, DateTime)
-        assert isinstance(func.myfunc().type, DateTime)
+        class Replacable_Func_override(GenericFunction):
+            __return_type__ = Integer
+            identifier = 'Replacable_Func'
+
+        assert isinstance(func.Replacable_Func().type, Integer)
+        assert isinstance(func.RePlAcaBlE_fUnC().type, NullType)
+        assert isinstance(func.replacable_func().type, DateTime)
+
+    def test_custom_legacy_case_sensitive(self):
+        # in case someone was using this system
+        with testing.expect_deprecated():
+            class MyFunc(GenericFunction):
+                __return_type__ = Integer
+
+        assert isinstance(func.MyFunc().type, Integer)
+        with pytest.raises(AssertionError):
+            assert isinstance(func.mYfUnC().type, Integer)
+        with pytest.raises(AssertionError):
+            assert isinstance(func.myfunc().type, Integer)
 
     def test_custom_w_custom_name(self):
         class myfunc(GenericFunction):
@@ -287,7 +317,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             type = Integer
             name = "BufferFour"
             identifier = "Buf4"
-            case_sensitive = False
 
         self.assert_compile(func.geo.buf1(), "BufferOne()")
         self.assert_compile(func.buf2(), "BufferTwo()")
@@ -298,17 +327,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(func.bUf4_(), "BufferFour()")
         self.assert_compile(func.buf4(), "BufferFour()")
 
-        class geobufferfour(GenericFunction):
-            type = Integer
-            name = "BufferFour"
-            identifier = "Buf4"
-            case_sensitive = True
+        with testing.expect_deprecated():
+            class geobufferfour_lowercase(GenericFunction):
+                type = Integer
+                name = "BufferFour_lowercase"
+                identifier = "buf4"
+
+        with testing.expect_deprecated():
+            class geobufferfour_random_case(GenericFunction):
+                type = Integer
+                name = "BuFferFouR"
+                identifier = "BuF4"
 
         self.assert_compile(func.Buf4(), "BufferFour()")
-        with pytest.raises(AssertionError):
-            self.assert_compile(func.BuF4(), "BufferFour()")
-        with pytest.raises(AssertionError):
-            self.assert_compile(func.buf4(), "BufferFour()")
+        self.assert_compile(func.BuF4(), "BuFferFouR()")
+        self.assert_compile(func.buf4(), "BufferFour_lowercase()")
 
     def test_custom_args(self):
         class myfunc(GenericFunction):