From 60f2da20902258573c5b50bd0ecefa256c0b24c8 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Fri, 29 Mar 2019 15:07:23 +0100 Subject: [PATCH] Clean code and tests. Use sqlalchemy warnings. --- lib/sqlalchemy/sql/functions.py | 25 ++++++------ test/sql/test_deprecations.py | 70 +++++++++++++++++++++++++++++++++ test/sql/test_functions.py | 58 +++++---------------------- 3 files changed, 93 insertions(+), 60 deletions(-) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index d104925a33..3c268a5972 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -60,31 +60,32 @@ def register_function(identifier, fn, package="_default"): identifier = identifier.lower() # Check if a function with the same lowercase identifier is registered. - if identifier in _case_sensitive_reg[package]: + if identifier in _registry[package]: if ( raw_identifier not in _case_sensitive_reg[package][identifier] ): - warnings.warn( + util.warn_deprecated( "GenericFunction(s) '{}' are already registered with " "different letter cases and might interact with {}.".format( _case_sensitive_reg[package][identifier], - raw_identifier), - sa_exc.SADeprecationWarning) + raw_identifier)) + # If a function with the same lowercase identifier is registered, + # then these 2 functions are considered as case-sensitive. + if len(_case_sensitive_reg[package][identifier]) == 1: + old_fn = reg[identifier] + del reg[identifier] + reg[_case_sensitive_reg[package][identifier][0]] = old_fn else: warnings.warn( "The GenericFunction '{}' is already registered and " - "is going to be overriden.".format(identifier)) - - # If a function with the same lowercase identifier is registered, then - # these 2 functions are considered as case-sensitive. - if len(_case_sensitive_reg[package][identifier]) == 1: - old_fn = reg[identifier] - del reg[identifier] - reg[_case_sensitive_reg[package][identifier][0]] = old_fn + "is going to be overriden.".format(identifier), + sa_exc.SAWarning) + reg[raw_identifier] = fn else: reg[identifier] = fn + # Add the raw_identifier to the case-sensitive registry if raw_identifier not in _case_sensitive_reg[package][identifier]: _case_sensitive_reg[package][identifier].append(raw_identifier) diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index 7990cd56c6..f84859794a 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -1,11 +1,16 @@ #! coding: utf-8 +from copy import deepcopy +import pytest + from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import create_engine +from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import select @@ -18,6 +23,9 @@ from sqlalchemy import util from sqlalchemy.engine import default from sqlalchemy.schema import DDL from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import functions +from sqlalchemy.sql.functions import GenericFunction +from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -25,11 +33,20 @@ from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +from sqlalchemy.testing.assertions import expect_warnings class DeprecationWarningsTest(fixtures.TestBase): __backend__ = True + def setup_method(self): + self._registry = deepcopy(functions._registry) + self._case_sensitive_reg = deepcopy(functions._case_sensitive_reg) + + def teardown_method(self): + functions._registry = self._registry + functions._case_sensitive_reg = self._case_sensitive_reg + def test_ident_preparer_force(self): preparer = testing.db.dialect.identifier_preparer preparer.quote("hi") @@ -135,6 +152,59 @@ class DeprecationWarningsTest(fixtures.TestBase): autoload_with=testing.db, ) + def test_case_sensitive(self): + class MYFUNC(GenericFunction): + type = DateTime + + assert isinstance(func.MYFUNC().type, DateTime) + assert isinstance(func.MyFunc().type, DateTime) + assert isinstance(func.mYfUnC().type, DateTime) + assert isinstance(func.myfunc().type, DateTime) + + with testing.expect_deprecated(): + class MyFunc(GenericFunction): + type = Integer + + assert isinstance(func.MYFUNC().type, DateTime) + assert isinstance(func.MyFunc().type, Integer) + with pytest.raises(AssertionError): + assert isinstance(func.mYfUnC().type, Integer) + with pytest.raises(AssertionError): + assert isinstance(func.myfunc().type, Integer) + + def test_replace_function_case_sensitive(self): + + class replacable_func(GenericFunction): + __return_type__ = Integer + identifier = 'replacable_func' + + assert isinstance(func.Replacable_Func().type, Integer) + assert isinstance(func.RePlAcaBlE_fUnC().type, Integer) + assert isinstance(func.replacable_func().type, Integer) + + with testing.expect_deprecated(): + class Replacable_Func(GenericFunction): + __return_type__ = DateTime + identifier = 'Replacable_Func' + + assert isinstance(func.Replacable_Func().type, DateTime) + assert isinstance(func.RePlAcaBlE_fUnC().type, NullType) + assert isinstance(func.replacable_func().type, Integer) + + with expect_warnings(): + class replacable_func_override(GenericFunction): + __return_type__ = DateTime + identifier = 'replacable_func' + + with expect_warnings(): + class Replacable_Func_override(GenericFunction): + __return_type__ = Integer + identifier = 'Replacable_Func' + + assert isinstance(func.Replacable_Func().type, Integer) + assert isinstance(func.RePlAcaBlE_fUnC().type, NullType) + assert isinstance(func.replacable_func().type, DateTime) + class DDLListenerDeprecationsTest(fixtures.TestBase): def setup(self): diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 311b476a80..62107ee5bd 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -2,8 +2,6 @@ from copy import deepcopy import datetime import decimal -import pytest - from sqlalchemy import ARRAY from sqlalchemy import bindparam from sqlalchemy import Boolean @@ -35,7 +33,6 @@ from sqlalchemy.sql import table from sqlalchemy.sql.compiler import BIND_TEMPLATES from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql.functions import GenericFunction -from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -43,6 +40,7 @@ from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.engines import all_dialects @@ -96,6 +94,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect, ) + functions._registry['_default'].pop('fake_func') + functions._case_sensitive_reg['_default'].pop('fake_func') + def test_use_labels(self): self.assert_compile( select([func.foo()], use_labels=True), "SELECT foo() AS foo_1" @@ -237,46 +238,23 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): assert isinstance(func.mYfUnC().type, DateTime) assert isinstance(func.myfunc().type, DateTime) - with testing.expect_deprecated(): - class MyFunc(GenericFunction): - type = Integer - - assert isinstance(func.MYFUNC().type, DateTime) - assert isinstance(func.MyFunc().type, Integer) - with pytest.raises(AssertionError): - assert isinstance(func.mYfUnC().type, Integer) - with pytest.raises(AssertionError): - assert isinstance(func.myfunc().type, Integer) - def test_replace_function(self): class replacable_func(GenericFunction): - __return_type__ = Integer + type = Integer identifier = 'replacable_func' assert isinstance(func.Replacable_Func().type, Integer) assert isinstance(func.RePlAcaBlE_fUnC().type, Integer) assert isinstance(func.replacable_func().type, Integer) - with testing.expect_deprecated(): - class Replacable_Func(GenericFunction): - __return_type__ = DateTime - identifier = 'Replacable_Func' + with expect_warnings(): + class replacable_func_override(GenericFunction): + type = DateTime + identifier = 'replacable_func' assert isinstance(func.Replacable_Func().type, DateTime) - assert isinstance(func.RePlAcaBlE_fUnC().type, NullType) - assert isinstance(func.replacable_func().type, Integer) - - class replacable_func_override(GenericFunction): - __return_type__ = DateTime - identifier = 'replacable_func' - - class Replacable_Func_override(GenericFunction): - __return_type__ = Integer - identifier = 'Replacable_Func' - - assert isinstance(func.Replacable_Func().type, Integer) - assert isinstance(func.RePlAcaBlE_fUnC().type, NullType) + assert isinstance(func.RePlAcaBlE_fUnC().type, DateTime) assert isinstance(func.replacable_func().type, DateTime) def test_custom_w_custom_name(self): @@ -341,22 +319,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(func.bUf4_(), "BufferFour()") self.assert_compile(func.buf4(), "BufferFour()") - with testing.expect_deprecated(): - class geobufferfour_lowercase(GenericFunction): - type = Integer - name = "BufferFour_lowercase" - identifier = "buf4" - - with testing.expect_deprecated(): - class geobufferfour_random_case(GenericFunction): - type = Integer - name = "BuFferFouR" - identifier = "BuF4" - - self.assert_compile(func.Buf4(), "BufferFour()") - self.assert_compile(func.BuF4(), "BuFferFouR()") - self.assert_compile(func.buf4(), "BufferFour_lowercase()") - def test_custom_args(self): class myfunc(GenericFunction): pass -- 2.47.3