From: Adrien Berchet Date: Wed, 27 Mar 2019 19:03:14 +0000 (+0100) Subject: Use lower case by default. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0a8b385270d4009996c56cecb5a956bfe2f7357e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Use lower case by default. --- diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index aadbb73d81..c914b3dd63 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -8,6 +8,9 @@ """SQL function API, factories, and built-in functions. """ +from functools import partial +import warnings + from . import annotation from . import operators from . import schema @@ -33,15 +36,17 @@ from .selectable import Alias from .selectable import FromClause from .selectable import Select from .visitors import VisitableType +from .. import exc as sa_exc from .. import util +_RegistryList = partial(util.defaultdict, list) + _registry = util.defaultdict(dict) -_case_insensitive_functions = util.defaultdict(dict) +_case_sensitive_reg = util.defaultdict(_RegistryList) -def register_function(identifier, fn, package="_default", - case_sensitive=True): +def register_function(identifier, fn, package="_default"): """Associate a callable with a particular func. name. This is normally called by _GenericMeta, but is also @@ -51,11 +56,37 @@ def register_function(identifier, fn, package="_default", """ reg = _registry[package] - reg[identifier] = fn - if not case_sensitive: - _case_insensitive_functions[package][identifier.lower()] = identifier - elif identifier.lower() in _case_insensitive_functions[package]: - del _case_insensitive_functions[package][identifier.lower()] + raw_identifier = identifier + identifier = identifier.lower() + + # Check if a function with the same lowercase identifier is registered. + if identifier in _case_sensitive_reg[package]: + if ( + raw_identifier not in _case_sensitive_reg[package][identifier] + ): + warnings.warn( + "GenericFunction(s) '{}' are already registered with " + "different letter cases and might interact with {}.".format( + _case_sensitive_reg[package][identifier], + raw_identifier), + sa_exc.SADeprecationWarning) + else: + warnings.warn( + "The GenericFunction '{}' is already registered and " + "is going to be overriden.".format(identifier)) + + # If a function with the same lowercase identifier is registered, then + # these 2 functions are considered as case-sensitive. + if len(_case_sensitive_reg[package][identifier]) == 1: + old_fn = reg[identifier] + del reg[identifier] + reg[_case_sensitive_reg[package][identifier][0]] = old_fn + reg[raw_identifier] = fn + else: + reg[identifier] = fn + + if raw_identifier not in _case_sensitive_reg[package][identifier]: + _case_sensitive_reg[package][identifier].append(raw_identifier) class FunctionElement(Executable, ColumnElement, FromClause): @@ -446,8 +477,12 @@ class _FunctionGenerator(object): package = None if package is not None: - func = _registry[package].get(fname) or _registry[package].get( - _case_insensitive_functions[package].get(fname.lower())) + if ( + len(_case_sensitive_reg[package].get(fname.lower(), [])) > 1 + ): + func = _registry[package].get(fname) + else: + func = _registry[package].get(fname.lower()) if func is not None: return func(*c, **o) @@ -577,13 +612,11 @@ class _GenericMeta(VisitableType): if annotation.Annotated not in cls.__mro__: cls.name = name = clsdict.get("name", clsname) cls.identifier = identifier = clsdict.get("identifier", name) - cls.case_sensitive = case_sensitive = clsdict.get( - "case_sensitive", True) package = clsdict.pop("package", "_default") # legacy if "__return_type__" in clsdict: cls.type = clsdict["__return_type__"] - register_function(identifier, cls, package, case_sensitive) + register_function(identifier, cls, package) super(_GenericMeta, cls).__init__(clsname, bases, clsdict) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index a2ed9f0486..0518ee4e49 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -33,6 +33,7 @@ from sqlalchemy.sql import table from sqlalchemy.sql.compiler import BIND_TEMPLATES from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql.functions import GenericFunction +from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -56,7 +57,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def tear_down(self): functions._registry.clear() - functions._case_insensitive_functions.clear() + functions._case_sensitive_functions.clear() def test_compile(self): for dialect in all_dialects(exclude=("sybase",)): @@ -102,7 +103,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_uppercase(self): # for now, we need to keep case insensitivity - self.assert_compile(func.NOW(), "NOW()") + self.assert_compile(func.UNREGISTERED_FN(), "UNREGISTERED_FN()") def test_uppercase_packages(self): # for now, we need to keep case insensitivity @@ -219,21 +220,50 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): class myfunc(GenericFunction): __return_type__ = DateTime - with pytest.raises(AssertionError): - assert isinstance(func.MyFunc().type, DateTime) - with pytest.raises(AssertionError): - assert isinstance(func.mYfUnC().type, DateTime) assert isinstance(func.myfunc().type, DateTime) - def test_custom_legacy_type_case_insensitive(self): - # in case someone was using this system - class MyFunc(GenericFunction): + def test_replace_function(self): + + class replacable_func(GenericFunction): + __return_type__ = Integer + identifier = 'replacable_func' + + assert isinstance(func.Replacable_Func().type, Integer) + assert isinstance(func.RePlAcaBlE_fUnC().type, Integer) + assert isinstance(func.replacable_func().type, Integer) + + with testing.expect_deprecated(): + class Replacable_Func(GenericFunction): + __return_type__ = DateTime + identifier = 'Replacable_Func' + + assert isinstance(func.Replacable_Func().type, DateTime) + assert isinstance(func.RePlAcaBlE_fUnC().type, NullType) + assert isinstance(func.replacable_func().type, Integer) + + class replacable_func_override(GenericFunction): __return_type__ = DateTime - case_sensitive = False + identifier = 'replacable_func' - assert isinstance(func.MyFunc().type, DateTime) - assert isinstance(func.mYfUnC().type, DateTime) - assert isinstance(func.myfunc().type, DateTime) + class Replacable_Func_override(GenericFunction): + __return_type__ = Integer + identifier = 'Replacable_Func' + + assert isinstance(func.Replacable_Func().type, Integer) + assert isinstance(func.RePlAcaBlE_fUnC().type, NullType) + assert isinstance(func.replacable_func().type, DateTime) + + def test_custom_legacy_case_sensitive(self): + # in case someone was using this system + with testing.expect_deprecated(): + class MyFunc(GenericFunction): + __return_type__ = Integer + + assert isinstance(func.MyFunc().type, Integer) + with pytest.raises(AssertionError): + assert isinstance(func.mYfUnC().type, Integer) + with pytest.raises(AssertionError): + assert isinstance(func.myfunc().type, Integer) def test_custom_w_custom_name(self): class myfunc(GenericFunction): @@ -287,7 +317,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): type = Integer name = "BufferFour" identifier = "Buf4" - case_sensitive = False self.assert_compile(func.geo.buf1(), "BufferOne()") self.assert_compile(func.buf2(), "BufferTwo()") @@ -298,17 +327,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(func.bUf4_(), "BufferFour()") self.assert_compile(func.buf4(), "BufferFour()") - class geobufferfour(GenericFunction): - type = Integer - name = "BufferFour" - identifier = "Buf4" - case_sensitive = True + with testing.expect_deprecated(): + class geobufferfour_lowercase(GenericFunction): + type = Integer + name = "BufferFour_lowercase" + identifier = "buf4" + + with testing.expect_deprecated(): + class geobufferfour_random_case(GenericFunction): + type = Integer + name = "BuFferFouR" + identifier = "BuF4" self.assert_compile(func.Buf4(), "BufferFour()") - with pytest.raises(AssertionError): - self.assert_compile(func.BuF4(), "BufferFour()") - with pytest.raises(AssertionError): - self.assert_compile(func.buf4(), "BufferFour()") + self.assert_compile(func.BuF4(), "BuFferFouR()") + self.assert_compile(func.buf4(), "BufferFour_lowercase()") def test_custom_args(self): class myfunc(GenericFunction):