From: Andrew Hannigan Date: Thu, 19 Nov 2020 01:12:10 +0000 (-0600) Subject: Modify TypeEngine._generic_type_affinity() with updated mro search X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ae4266d8707112d39fa5cccdab596d178f3f58f0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Modify TypeEngine._generic_type_affinity() with updated mro search --- diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index e4d66c54f7..6f40f95eca 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -456,12 +456,12 @@ class TypeEngine(Traversible): else: return self.__class__ - @classmethod - def _is_generic_type(cls): - n = cls.__name__ - return n.upper() != n - def _generic_type_affinity(self): + best_camelcase = None + best_uppercase = None + + if not isinstance(self, (TypeEngine, UserDefinedType)): + return self.__class__ for t in self.__class__.__mro__: if ( @@ -470,13 +470,15 @@ class TypeEngine(Traversible): "sqlalchemy.sql.sqltypes", "sqlalchemy.sql.type_api", ) - and hasattr(t, '_is_generic_type') and t._is_generic_type() + and issubclass(t, TypeEngine) + and t is not TypeEngine ): - if t in (TypeEngine, UserDefinedType): - return NULLTYPE.__class__ - return t - else: - return self.__class__ + if t.__name__.isupper() and not best_uppercase: + best_uppercase = t + elif not t.__name__.isupper() and not best_camelcase: + best_camelcase = t + + return best_camelcase or best_uppercase or NULLTYPE.__class__ def as_generic(self): """Return an instance of the generic type corresponding to this type""" diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 570383dfed..a49307e940 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -6,6 +6,7 @@ import operator import os import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg from sqlalchemy import and_ from sqlalchemy import ARRAY from sqlalchemy import BigInteger @@ -389,6 +390,8 @@ class AsGenericTest(fixtures.TestBase): (VARCHAR(length=100), String(length=100)), (NVARCHAR(length=100), Unicode(length=100)), (DATE(), Date()), + (pg.JSON(), sa.JSON()), + (pg.ARRAY(sa.String), sa.ARRAY(sa.String)), ) def test_as_generic(self, t1, t2): assert repr(t1.as_generic()) == repr(t2)