]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Modify TypeEngine._generic_type_affinity() with updated mro search
authorAndrew Hannigan <andrewhannigan@Andrews-MacBook-Pro.local>
Thu, 19 Nov 2020 01:12:10 +0000 (19:12 -0600)
committerAndrew Hannigan <andrewhannigan@Andrews-MacBook-Pro.local>
Thu, 19 Nov 2020 01:12:10 +0000 (19:12 -0600)
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

index e4d66c54f7254809206eed6b578ff3334e03d071..6f40f95ecaa3dc0220b1a318674b06ea0e52f9b5 100644 (file)
@@ -456,12 +456,12 @@ class TypeEngine(Traversible):
         else:
             return self.__class__
 
-    @classmethod
-    def _is_generic_type(cls):
-        n = cls.__name__
-        return n.upper() != n
-
     def _generic_type_affinity(self):
+        best_camelcase = None
+        best_uppercase = None
+
+        if not isinstance(self, (TypeEngine, UserDefinedType)):
+            return self.__class__
 
         for t in self.__class__.__mro__:
             if (
@@ -470,13 +470,15 @@ class TypeEngine(Traversible):
                     "sqlalchemy.sql.sqltypes",
                     "sqlalchemy.sql.type_api",
                 )
-                and hasattr(t, '_is_generic_type') and t._is_generic_type()
+                and issubclass(t, TypeEngine)
+                and t is not TypeEngine
             ):
-                if t in (TypeEngine, UserDefinedType):
-                    return NULLTYPE.__class__
-                return t
-        else:
-            return self.__class__
+                if t.__name__.isupper() and not best_uppercase:
+                    best_uppercase = t
+                elif not t.__name__.isupper() and not best_camelcase:
+                    best_camelcase = t
+
+        return best_camelcase or best_uppercase or NULLTYPE.__class__
 
     def as_generic(self):
         """Return an instance of the generic type corresponding to this type"""
index 570383dfedb838d2e428bea48b2f0e0ca1253049..a49307e940f6ffefd23411d9d80691a3c5669c55 100644 (file)
@@ -6,6 +6,7 @@ import operator
 import os
 
 import sqlalchemy as sa
+import sqlalchemy.dialects.postgresql as pg
 from sqlalchemy import and_
 from sqlalchemy import ARRAY
 from sqlalchemy import BigInteger
@@ -389,6 +390,8 @@ class AsGenericTest(fixtures.TestBase):
         (VARCHAR(length=100), String(length=100)),
         (NVARCHAR(length=100), Unicode(length=100)),
         (DATE(), Date()),
+        (pg.JSON(), sa.JSON()),
+        (pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
     )
     def test_as_generic(self, t1, t2):
         assert repr(t1.as_generic()) == repr(t2)