From c29ef4833c4ff111925138d9ee4f7184c67b2f9a Mon Sep 17 00:00:00 2001 From: Andrew Hannigan Date: Sun, 6 Dec 2020 12:52:45 -0600 Subject: [PATCH] Incorporate PR notes --- lib/sqlalchemy/dialects/oracle/base.py | 7 ++++ lib/sqlalchemy/dialects/postgresql/base.py | 3 ++ lib/sqlalchemy/sql/type_api.py | 41 +++++++++++++++++++--- test/sql/test_types.py | 38 +++++++++++++++++++- 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 7bdaafd82d..e3d6e1ed9b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -666,6 +666,13 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): def _type_affinity(self): return sqltypes.Interval + def as_generic(self): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + class ROWID(sqltypes.TypeEngine): """Oracle ROWID type. diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3c33d9ee8e..2571d10de5 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1434,6 +1434,9 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): def _type_affinity(self): return sqltypes.Interval + def as_generic(self): + return sqltypes.Interval(native=True, second_precision=self.precision) + @property def python_type(self): return dt.timedelta diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 9de93aaed5..f4372626c4 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -17,7 +17,6 @@ from .visitors import TraversibleType from .. import exc from .. import util - # these are back-assigned by sqltypes. BOOLEANTYPE = None INTEGERTYPE = None @@ -456,6 +455,7 @@ class TypeEngine(Traversible): else: return self.__class__ + @util.memoized_property def _generic_type_affinity(self): best_camelcase = None best_uppercase = None @@ -472,6 +472,7 @@ class TypeEngine(Traversible): ) and issubclass(t, TypeEngine) and t is not TypeEngine + and t.__name__[0] != "_" ): if t.__name__.isupper() and not best_uppercase: best_uppercase = t @@ -480,9 +481,11 @@ class TypeEngine(Traversible): return best_camelcase or best_uppercase or NULLTYPE.__class__ - def as_generic(self): + def as_generic(self, allow_nulltype=False): """ - Return an instance of the generic type corresponding to this type. + Return an instance of the generic type corresponding to this type using + heuristic rule. The method may be overridden if this heuristic rule is not + sufficient. >>> from sqlalchemy.dialects.mysql import INTEGER >>> INTEGER(display_width=4).as_generic() @@ -494,8 +497,38 @@ class TypeEngine(Traversible): .. versionadded:: 1.4.0b2 """ + from sqlalchemy import Enum - return util.constructor_copy(self, self._generic_type_affinity()) + if isinstance(self, Enum): + if hasattr(self, "enums"): + args = self.enums + else: + raise NotImplementedError( + "TypeEngine.as_generic() heuristic " + "is undefined for types that inherit Enum but do not have " + "an `enums` attribute." + ) + else: + args = () + + if ( + not allow_nulltype + and self._generic_type_affinity == NULLTYPE.__class__ + ): + raise NotImplementedError( + "Default TypeEngine.as_generic() " + "heuristic method was unsuccessful for {}. A custom " + "as_generic() method must be implemented for this " + "type class.".format( + self.__class__.__module__ + "." + self.__class__.__name__ + ) + ) + + return util.constructor_copy(self, self._generic_type_affinity, *args) + + @classmethod + def _uses_as_generic_heuristic(cls): + return cls.as_generic == TypeEngine.as_generic def dialect_impl(self, dialect): """Return a dialect-specific implementation for this diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 3c9e08fff9..6fe413ab61 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -58,6 +58,7 @@ from sqlalchemy import types from sqlalchemy import Unicode from sqlalchemy import util from sqlalchemy import VARCHAR +import sqlalchemy.dialects.mysql as mysql import sqlalchemy.dialects.postgresql as pg from sqlalchemy.engine import default from sqlalchemy.schema import AddConstraint @@ -70,6 +71,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import visitors +from sqlalchemy.sql.sqltypes import TypeEngine from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -392,6 +394,9 @@ class AsGenericTest(fixtures.TestBase): (DATE(), Date()), (pg.JSON(), sa.JSON()), (pg.ARRAY(sa.String), sa.ARRAY(sa.String)), + (Enum("a", "b", "c"), Enum("a", "b", "c")), + (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")), + (mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")), ) def test_as_generic(self, t1, t2): assert repr(t1.as_generic()) == repr(t2) @@ -403,7 +408,38 @@ class AsGenericTest(fixtures.TestBase): else: t1 = type_() - t1.as_generic() + try: + gentype = t1.as_generic() + except NotImplementedError as e: + pass + else: + if t1.__class__._uses_as_generic_heuristic(): + assert isinstance(t1, gentype.__class__) + + assert isinstance(gentype, TypeEngine) + + @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)]) + def test_as_generic_all_types_allow_nulltype(self, type_): + if issubclass(type_, ARRAY): + t1 = type_(String) + else: + t1 = type_() + + # The `allow_nulltype` argument may not be available in custom + # implementations of as_generic() which override the + # TypeEngine.as_generic heuristic. + if t1.__class__._uses_as_generic_heuristic(): + gentype = t1.as_generic(allow_nulltype=True) + else: + gentype = t1.as_generic() + + if isinstance(gentype, types.NULLTYPE.__class__): + return + + if t1.__class__._uses_as_generic_heuristic(): + assert isinstance(t1, gentype.__class__) + + assert isinstance(gentype, TypeEngine) class PickleTypesTest(fixtures.TestBase): -- 2.47.3