]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Incorporate PR notes
authorAndrew Hannigan <andrewhannigan@Andrews-MacBook-Pro.local>
Sun, 6 Dec 2020 18:52:45 +0000 (12:52 -0600)
committerAndrew Hannigan <andrewhannigan@Andrews-MacBook-Pro.local>
Sun, 6 Dec 2020 18:52:45 +0000 (12:52 -0600)
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

index 7bdaafd82d310f9be01eb57c86582a2d02a83c4c..e3d6e1ed9b146b64a22b32a23efd30634997d9ae 100644 (file)
@@ -666,6 +666,13 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
     def _type_affinity(self):
         return sqltypes.Interval
 
+    def as_generic(self):
+        return sqltypes.Interval(
+            native=True,
+            second_precision=self.second_precision,
+            day_precision=self.day_precision,
+        )
+
 
 class ROWID(sqltypes.TypeEngine):
     """Oracle ROWID type.
index 3c33d9ee8ec5964765eb6ec7cd4bc6e97273afa7..2571d10de51afa2ceb3c18e37de3374a3368cd3d 100644 (file)
@@ -1434,6 +1434,9 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
     def _type_affinity(self):
         return sqltypes.Interval
 
+    def as_generic(self):
+        return sqltypes.Interval(native=True, second_precision=self.precision)
+
     @property
     def python_type(self):
         return dt.timedelta
index 9de93aaed54afc80ca9cd071addb719fafa0b5fa..f4372626c409365fa3489ca0ba08453e9bf1f74e 100644 (file)
@@ -17,7 +17,6 @@ from .visitors import TraversibleType
 from .. import exc
 from .. import util
 
-
 # these are back-assigned by sqltypes.
 BOOLEANTYPE = None
 INTEGERTYPE = None
@@ -456,6 +455,7 @@ class TypeEngine(Traversible):
         else:
             return self.__class__
 
+    @util.memoized_property
     def _generic_type_affinity(self):
         best_camelcase = None
         best_uppercase = None
@@ -472,6 +472,7 @@ class TypeEngine(Traversible):
                 )
                 and issubclass(t, TypeEngine)
                 and t is not TypeEngine
+                and t.__name__[0] != "_"
             ):
                 if t.__name__.isupper() and not best_uppercase:
                     best_uppercase = t
@@ -480,9 +481,11 @@ class TypeEngine(Traversible):
 
         return best_camelcase or best_uppercase or NULLTYPE.__class__
 
-    def as_generic(self):
+    def as_generic(self, allow_nulltype=False):
         """
-        Return an instance of the generic type corresponding to this type.
+        Return an instance of the generic type corresponding to this type using
+        heuristic rule. The method may be overridden if this heuristic rule is not
+        sufficient.
 
         >>> from sqlalchemy.dialects.mysql import INTEGER
         >>> INTEGER(display_width=4).as_generic()
@@ -494,8 +497,38 @@ class TypeEngine(Traversible):
 
         .. versionadded:: 1.4.0b2
         """
+        from sqlalchemy import Enum
 
-        return util.constructor_copy(self, self._generic_type_affinity())
+        if isinstance(self, Enum):
+            if hasattr(self, "enums"):
+                args = self.enums
+            else:
+                raise NotImplementedError(
+                    "TypeEngine.as_generic() heuristic "
+                    "is undefined for types that inherit Enum but do not have "
+                    "an `enums` attribute."
+                )
+        else:
+            args = ()
+
+        if (
+            not allow_nulltype
+            and self._generic_type_affinity == NULLTYPE.__class__
+        ):
+            raise NotImplementedError(
+                "Default TypeEngine.as_generic() "
+                "heuristic method was unsuccessful for {}. A custom "
+                "as_generic() method must be implemented for this "
+                "type class.".format(
+                    self.__class__.__module__ + "." + self.__class__.__name__
+                )
+            )
+
+        return util.constructor_copy(self, self._generic_type_affinity, *args)
+
+    @classmethod
+    def _uses_as_generic_heuristic(cls):
+        return cls.as_generic == TypeEngine.as_generic
 
     def dialect_impl(self, dialect):
         """Return a dialect-specific implementation for this
index 3c9e08fff90d2706c191a491b96da8b5794026fc..6fe413ab6176066da5696833cb0e730df096a49c 100644 (file)
@@ -58,6 +58,7 @@ from sqlalchemy import types
 from sqlalchemy import Unicode
 from sqlalchemy import util
 from sqlalchemy import VARCHAR
+import sqlalchemy.dialects.mysql as mysql
 import sqlalchemy.dialects.postgresql as pg
 from sqlalchemy.engine import default
 from sqlalchemy.schema import AddConstraint
@@ -70,6 +71,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.sqltypes import TypeEngine
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -392,6 +394,9 @@ class AsGenericTest(fixtures.TestBase):
         (DATE(), Date()),
         (pg.JSON(), sa.JSON()),
         (pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
+        (Enum("a", "b", "c"), Enum("a", "b", "c")),
+        (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")),
+        (mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")),
     )
     def test_as_generic(self, t1, t2):
         assert repr(t1.as_generic()) == repr(t2)
@@ -403,7 +408,38 @@ class AsGenericTest(fixtures.TestBase):
         else:
             t1 = type_()
 
-        t1.as_generic()
+        try:
+            gentype = t1.as_generic()
+        except NotImplementedError as e:
+            pass
+        else:
+            if t1.__class__._uses_as_generic_heuristic():
+                assert isinstance(t1, gentype.__class__)
+
+            assert isinstance(gentype, TypeEngine)
+
+    @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
+    def test_as_generic_all_types_allow_nulltype(self, type_):
+        if issubclass(type_, ARRAY):
+            t1 = type_(String)
+        else:
+            t1 = type_()
+
+        # The `allow_nulltype` argument may not be available in custom
+        # implementations of as_generic() which override the
+        # TypeEngine.as_generic heuristic.
+        if t1.__class__._uses_as_generic_heuristic():
+            gentype = t1.as_generic(allow_nulltype=True)
+        else:
+            gentype = t1.as_generic()
+
+        if isinstance(gentype, types.NULLTYPE.__class__):
+            return
+
+        if t1.__class__._uses_as_generic_heuristic():
+            assert isinstance(t1, gentype.__class__)
+
+        assert isinstance(gentype, TypeEngine)
 
 
 class PickleTypesTest(fixtures.TestBase):