]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement `TypeEngine.as_generic`
author Gord Thompson <gord@gordthompson.com>
Mon, 7 Dec 2020 23:37:29 +0000 (18:37 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Dec 2020 00:54:05 +0000 (19:54 -0500)
Added :meth:`_types.TypeEngine.as_generic` to map dialect-specific types,
such as :class:`sqlalchemy.dialects.mysql.INTEGER`, with the "best match"
generic SQLAlchemy type, in this case :class:`_types.Integer`.  Pull
request courtesy Andrew Hannigan.

Abstract away how we check for "overridden methods" so it is more
clear what the intent is and that the methodology can be
independently tested.

Fixes: #5659
Closes: #5714
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5714
Pull-request-sha: 91afb9a0ba3bfa81a1ded80c025989213cf6e4eb

Change-Id: Ic54d6690ecc10dc69e6e72856d5620036cea472a

doc/build/changelog/unreleased_14/5659.rst [new file with mode: 0644]
doc/build/core/reflection.rst
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/events.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_utils.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/5659.rst b/doc/build/changelog/unreleased_14/5659.rst
new file mode 100644 (file)
index 0000000..e57ae0b
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: schema, feature
+    :tickets: 5659
+
+    Added :meth:`_types.TypeEngine.as_generic` to map dialect-specific types,
+    such as :class:`sqlalchemy.dialects.mysql.INTEGER`, with the "best match"
+    generic SQLAlchemy type, in this case :class:`_types.Integer`.  Pull
+    request courtesy Andrew Hannigan.
+
+    .. seealso::
+
+      :ref:`metadata_reflection_dbagnostic_types` - example usage
index b95d3e44a5e58dbcc9f9021f9ab707302656776f..9cd75b814f9442f2e45cb12447710ca7472fcbc2 100644 (file)
@@ -140,6 +140,121 @@ database is also available. This is known as the "Inspector"::
     :members:
     :undoc-members:
 
+.. _metadata_reflection_dbagnostic_types:
+
+Reflecting with Database-Agnostic Types
+---------------------------------------
+
+When the columns of a table are reflected, using either the
+:paramref:`_schema.Table.autoload_with` parameter of :class:`_schema.Table` or
+the :meth:`_reflection.Inspector.get_columns` method of
+:class:`_reflection.Inspector`, the datatypes will be as specific as possible
+to the target database.   This means that if an "integer" datatype is reflected
+from a MySQL database, the type will be represented by the
+:class:`sqlalchemy.dialects.mysql.INTEGER` class, which includes MySQL-specific
+attributes such as "display_width".   Or on PostgreSQL, a PostgreSQL-specific
+datatype such as :class:`sqlalchemy.dialects.postgresql.INTERVAL` or
+:class:`sqlalchemy.dialects.postgresql.ENUM` may be returned.
+
+There is a use case for reflection which is that a given :class:`_schema.Table`
+is to be transferred to a different vendor database.   To suit this use case,
+there is a technique by which these vendor-specific datatypes can be converted
+on the fly to be instances of SQLAlchemy backend-agnostic datatypes, for
+the examples above types such as :class:`_types.Integer`, :class:`_types.Interval`
+and :class:`_types.Enum`.   This may be achieved by intercepting the
+column reflection using the :meth:`_events.DDLEvents.column_reflect` event
+in conjunction with the :meth:`_types.TypeEngine.as_generic` method.
+
+Given a table in MySQL (chosen because MySQL has a lot of vendor-specific
+datatypes and options)::
+
+    CREATE TABLE IF NOT EXISTS my_table (
+        id INTEGER PRIMARY KEY AUTO_INCREMENT,
+        data1 VARCHAR(50) CHARACTER SET latin1,
+        data2 MEDIUMINT(4),
+        data3 TINYINT(2)
+    )
+
+The above table includes MySQL-only integer types ``MEDIUMINT`` and
+``TINYINT`` as well as a ``VARCHAR`` that includes the MySQL-only ``CHARACTER
+SET`` option.   If we reflect this table normally, it produces a
+:class:`_schema.Table` object that will contain those MySQL-specific datatypes
+and options:
+
+.. sourcecode:: pycon+sql
+
+    >>> from sqlalchemy import MetaData, Table, create_engine
+    >>> mysql_engine = create_engine("mysql://scott:tiger@localhost/test")
+    >>> metadata = MetaData()
+    >>> my_mysql_table = Table("my_table", metadata, autoload_with=mysql_engine)
+
+The above example reflects the above table schema into a new :class:`_schema.Table`
+object.  We can then, for demonstration purposes, print out the MySQL-specific
+"CREATE TABLE" statement using the :class:`_schema.CreateTable` construct:
+
+.. sourcecode:: pycon+sql
+
+    >>> from sqlalchemy.schema import CreateTable
+    >>> print(CreateTable(my_mysql_table).compile(mysql_engine))
+    {opensql}CREATE TABLE my_table (
+    id INTEGER(11) NOT NULL AUTO_INCREMENT,
+    data1 VARCHAR(50) CHARACTER SET latin1,
+    data2 MEDIUMINT(4),
+    data3 TINYINT(2),
+    PRIMARY KEY (id)
+    )ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
+
+
+Above, the MySQL-specific datatypes and options were maintained.   If we wanted
+a :class:`_schema.Table` that we could instead transfer cleanly to another
+database vendor, replacing the special datatypes
+:class:`sqlalchemy.dialects.mysql.MEDIUMINT` and
+:class:`sqlalchemy.dialects.mysql.TINYINT` with :class:`_types.Integer`, we can
+choose instead to "genericize" the datatypes on this table, or otherwise change
+them in any way we'd like, by establishing a handler using the
+:meth:`_events.DDLEvents.column_reflect` event.  The custom handler will make use
+of the :meth:`_types.TypeEngine.as_generic` method to convert the above
+MySQL-specific type objects into generic ones, by replacing the ``"type"``
+entry within the column dictionary entry that is passed to the event handler.
+The format of this dictionary is described at :meth:`_reflection.Inspector.get_columns`:
+
+.. sourcecode:: pycon+sql
+
+    >>> from sqlalchemy import event
+    >>> metadata = MetaData()
+
+    >>> @event.listens_for(metadata, "column_reflect")
+    ... def genericize_datatypes(inspector, tablename, column_dict):
+    ...     column_dict["type"] = column_dict["type"].as_generic()
+
+    >>> my_generic_table = Table("my_table", metadata, autoload_with=mysql_engine)
+
+We now get a new :class:`_schema.Table` that is generic and uses
+:class:`_types.Integer` for those datatypes.  We can now emit a
+"CREATE TABLE" statement for example on a PostgreSQL database:
+
+.. sourcecode:: pycon+sql
+
+    >>> pg_engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
+    >>> my_generic_table.create(pg_engine)
+    {opensql}CREATE TABLE my_table (
+        id SERIAL NOT NULL,
+        data1 VARCHAR(50),
+        data2 INTEGER,
+        data3 INTEGER,
+        PRIMARY KEY (id)
+    )
+
+Note also in the above that SQLAlchemy will usually make a decent guess for
+other behaviors, such as that the MySQL ``AUTO_INCREMENT`` directive is
+represented in PostgreSQL most closely using the ``SERIAL`` auto-incrementing datatype.
+
+.. versionadded:: 1.4 Added the :meth:`_types.TypeEngine.as_generic` method
+   and additionally improved the use of the :meth:`_events.DDLEvents.column_reflect`
+   event such that it may be applied to a :class:`_schema.MetaData` object
+   for convenience.
+
+
 Limitations of Reflection
 -------------------------
 
index 223c1db98d183de5bc4c28921490ad7f17a89da5..371a6702e7a4a33d22f7ef3db85d64b672120a7c 100644 (file)
@@ -666,6 +666,13 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
     def _type_affinity(self):
         return sqltypes.Interval
 
+    def as_generic(self, allow_nulltype=False):
+        return sqltypes.Interval(
+            native=True,
+            second_precision=self.second_precision,
+            day_precision=self.day_precision,
+        )
+
 
 class ROWID(sqltypes.TypeEngine):
     """Oracle ROWID type.
index 612bc92230f8d4a926ab39e8522c6c72bb1a5106..e41e489c0df796758bdcb4670e5bf243b44ae918 100644 (file)
@@ -1474,6 +1474,9 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
     def _type_affinity(self):
         return sqltypes.Interval
 
+    def as_generic(self, allow_nulltype=False):
+        return sqltypes.Interval(native=True, second_precision=self.precision)
+
     @property
     def python_type(self):
         return dt.timedelta
index 58d04f7aa380a3408f4adf1111469fe0a533970b..797ca697fcb7bbc94b059021d5f279ba732d66e2 100644 (file)
@@ -314,4 +314,7 @@ class DDLEvents(event.Events):
             :ref:`automap_intercepting_columns` -
             in the :ref:`automap_toplevel` documentation
 
+            :ref:`metadata_reflection_dbagnostic_types` - in
+            the :ref:`metadata_reflection_toplevel` documentation
+
         """
index 45d4f0b7f1c24c75ef81e1dc52814dc31d9b8b0d..581573d17e282f3f4ed58684c7a119956d52e06a 100644 (file)
@@ -1607,6 +1607,18 @@ class Enum(Emulated, String, SchemaType):
             to_inspect=[Enum, SchemaType],
         )
 
+    def as_generic(self, allow_nulltype=False):
+        if hasattr(self, "enums"):
+            args = self.enums
+        else:
+            raise NotImplementedError(
+                "TypeEngine.as_generic() heuristic "
+                "is undefined for types that inherit Enum but do not have "
+                "an `enums` attribute."
+            )
+
+        return util.constructor_copy(self, self._generic_type_affinity, *args)
+
     def adapt_to_emulated(self, impltype, **kw):
         kw.setdefault("_expect_unicode", self._expect_unicode)
         kw.setdefault("validate_strings", self.validate_strings)
index bca6e9020e44b816db0ad4c221d1e076e26351f4..b48886cca70dc9b100c02da20aebacc8246c41a6 100644 (file)
@@ -17,7 +17,6 @@ from .visitors import TraversibleType
 from .. import exc
 from .. import util
 
-
 # these are back-assigned by sqltypes.
 BOOLEANTYPE = None
 INTEGERTYPE = None
@@ -372,10 +371,7 @@ class TypeEngine(Traversible):
 
         """
 
-        return (
-            self.__class__.bind_expression.__code__
-            is not TypeEngine.bind_expression.__code__
-        )
+        return util.method_is_overridden(self, TypeEngine.bind_expression)
 
     @staticmethod
     def _to_instance(cls_or_self):
@@ -456,12 +452,13 @@ class TypeEngine(Traversible):
         else:
             return self.__class__
 
-    @classmethod
-    def _is_generic_type(cls):
-        n = cls.__name__
-        return n.upper() != n
-
+    @util.memoized_property
     def _generic_type_affinity(self):
+        best_camelcase = None
+        best_uppercase = None
+
+        if not isinstance(self, (TypeEngine, UserDefinedType)):
+            return self.__class__
 
         for t in self.__class__.__mro__:
             if (
@@ -470,13 +467,56 @@ class TypeEngine(Traversible):
                     "sqlalchemy.sql.sqltypes",
                     "sqlalchemy.sql.type_api",
                 )
-                and t._is_generic_type()
+                and issubclass(t, TypeEngine)
+                and t is not TypeEngine
+                and t.__name__[0] != "_"
             ):
-                if t in (TypeEngine, UserDefinedType):
-                    return NULLTYPE.__class__
-                return t
-        else:
-            return self.__class__
+                if t.__name__.isupper() and not best_uppercase:
+                    best_uppercase = t
+                elif not t.__name__.isupper() and not best_camelcase:
+                    best_camelcase = t
+
+        return best_camelcase or best_uppercase or NULLTYPE.__class__
+
+    def as_generic(self, allow_nulltype=False):
+        """
+        Return an instance of the generic type corresponding to this type
+        using a heuristic rule. The method may be overridden if this
+        heuristic rule is not sufficient.
+
+        >>> from sqlalchemy.dialects.mysql import INTEGER
+        >>> INTEGER(display_width=4).as_generic()
+        Integer()
+
+        >>> from sqlalchemy.dialects.mysql import NVARCHAR
+        >>> NVARCHAR(length=100).as_generic()
+        Unicode(length=100)
+
+        .. versionadded:: 1.4.0b2
+
+
+        .. seealso::
+
+            :ref:`metadata_reflection_dbagnostic_types` - describes the
+            use of :meth:`_types.TypeEngine.as_generic` in conjunction with
+            the :meth:`_sql.DDLEvents.column_reflect` event, which is its
+            intended use.
+
+        """
+        if (
+            not allow_nulltype
+            and self._generic_type_affinity == NULLTYPE.__class__
+        ):
+            raise NotImplementedError(
+                "Default TypeEngine.as_generic() "
+                "heuristic method was unsuccessful for {}. A custom "
+                "as_generic() method must be implemented for this "
+                "type class.".format(
+                    self.__class__.__module__ + "." + self.__class__.__name__
+                )
+            )
+
+        return util.constructor_copy(self, self._generic_type_affinity)
 
     def dialect_impl(self, dialect):
         """Return a dialect-specific implementation for this
@@ -1171,18 +1211,16 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
 
         """
 
-        return (
-            self.__class__.process_bind_param.__code__
-            is not TypeDecorator.process_bind_param.__code__
+        return util.method_is_overridden(
+            self, TypeDecorator.process_bind_param
         )
 
     @util.memoized_property
     def _has_literal_processor(self):
         """memoized boolean, check if process_literal_param is implemented."""
 
-        return (
-            self.__class__.process_literal_param.__code__
-            is not TypeDecorator.process_literal_param.__code__
+        return util.method_is_overridden(
+            self, TypeDecorator.process_literal_param
         )
 
     def literal_processor(self, dialect):
@@ -1278,9 +1316,9 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         exception throw.
 
         """
-        return (
-            self.__class__.process_result_value.__code__
-            is not TypeDecorator.process_result_value.__code__
+
+        return util.method_is_overridden(
+            self, TypeDecorator.process_result_value
         )
 
     def result_processor(self, dialect, coltype):
@@ -1322,10 +1360,11 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
 
     @util.memoized_property
     def _has_bind_expression(self):
+
         return (
-            self.__class__.bind_expression.__code__
-            is not TypeDecorator.bind_expression.__code__
-        ) or self.impl._has_bind_expression
+            util.method_is_overridden(self, TypeDecorator.bind_expression)
+            or self.impl._has_bind_expression
+        )
 
     def bind_expression(self, bindparam):
         return self.impl.bind_expression(bindparam)
@@ -1340,9 +1379,9 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         """
 
         return (
-            self.__class__.column_expression.__code__
-            is not TypeDecorator.column_expression.__code__
-        ) or self.impl._has_column_expression
+            util.method_is_overridden(self, TypeDecorator.column_expression)
+            or self.impl._has_column_expression
+        )
 
     def column_expression(self, column):
         return self.impl.column_expression(column)
index 2db1adb8d2374982eea5200b0f81b854f9e4678e..f4363d03cef7d3c55a4fb10459481b3da623664f 100644 (file)
@@ -147,6 +147,7 @@ from .langhelpers import md5_hex  # noqa
 from .langhelpers import memoized_instancemethod  # noqa
 from .langhelpers import memoized_property  # noqa
 from .langhelpers import MemoizedSlots  # noqa
+from .langhelpers import method_is_overridden  # noqa
 from .langhelpers import methods_equivalent  # noqa
 from .langhelpers import monkeypatch_proxied_specials  # noqa
 from .langhelpers import NoneType  # noqa
index 8d6c2d8eeeff2e05b8d84889955db8966b0b12cc..b0963ce431ff57080e132c4d098325a195e21f1f 100644 (file)
@@ -114,6 +114,21 @@ def clsname_as_plain_name(cls):
     )
 
 
+def method_is_overridden(instance_or_cls, against_method):
+    """Return True if the two class methods don't match."""
+
+    if not isinstance(instance_or_cls, type):
+        current_cls = instance_or_cls.__class__
+    else:
+        current_cls = instance_or_cls
+
+    method_name = against_method.__name__
+
+    current_method = getattr(current_cls, method_name)
+
+    return current_method != against_method
+
+
 def decode_slice(slc):
     """decode a slice object as sent to __getitem__.
 
index 6d2bb60c958cbe0f5aaa4e6a0c391390f770399c..a00fbd0186297272606b9a17223368c2d82eb919 100644 (file)
@@ -3311,3 +3311,92 @@ class TestModuleRegistry(fixtures.TestBase):
             for name, mod in to_restore:
                 if mod is not None:
                     sys.modules[name] = mod
+
+
+class MethodOveriddenTest(fixtures.TestBase):
+    def test_subclass_overrides_cls_given(self):
+        class Foo(object):
+            def bar(self):
+                pass
+
+        class Bar(Foo):
+            def bar(self):
+                pass
+
+        is_true(util.method_is_overridden(Bar, Foo.bar))
+
+    def test_subclass_overrides(self):
+        class Foo(object):
+            def bar(self):
+                pass
+
+        class Bar(Foo):
+            def bar(self):
+                pass
+
+        is_true(util.method_is_overridden(Bar(), Foo.bar))
+
+    def test_subclass_overrides_skiplevel(self):
+        class Foo(object):
+            def bar(self):
+                pass
+
+        class Bar(Foo):
+            pass
+
+        class Bat(Bar):
+            def bar(self):
+                pass
+
+        is_true(util.method_is_overridden(Bat(), Foo.bar))
+
+    def test_subclass_overrides_twolevels(self):
+        class Foo(object):
+            def bar(self):
+                pass
+
+        class Bar(Foo):
+            def bar(self):
+                pass
+
+        class Bat(Bar):
+            pass
+
+        is_true(util.method_is_overridden(Bat(), Foo.bar))
+
+    def test_subclass_doesnt_override_cls_given(self):
+        class Foo(object):
+            def bar(self):
+                pass
+
+        class Bar(Foo):
+            pass
+
+        is_false(util.method_is_overridden(Bar, Foo.bar))
+
+    def test_subclass_doesnt_override(self):
+        class Foo(object):
+            def bar(self):
+                pass
+
+        class Bar(Foo):
+            pass
+
+        is_false(util.method_is_overridden(Bar(), Foo.bar))
+
+    def test_subclass_overrides_multi_mro(self):
+        class Base(object):
+            pass
+
+        class Foo(object):
+            pass
+
+        class Bat(Base):
+            def bar(self):
+                pass
+
+        class HoHo(Foo, Bat):
+            def bar(self):
+                pass
+
+        is_true(util.method_is_overridden(HoHo(), Bat.bar))
index fd1783e09806177acb953da3225526f493e20ec4..3178eb09ab09f4b91f20291e73caae99cab4527f 100644 (file)
@@ -58,6 +58,9 @@ from sqlalchemy import types
 from sqlalchemy import Unicode
 from sqlalchemy import util
 from sqlalchemy import VARCHAR
+import sqlalchemy.dialects.mysql as mysql
+import sqlalchemy.dialects.oracle as oracle
+import sqlalchemy.dialects.postgresql as pg
 from sqlalchemy.engine import default
 from sqlalchemy.schema import AddConstraint
 from sqlalchemy.schema import CheckConstraint
@@ -69,6 +72,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.sqltypes import TypeEngine
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -383,6 +387,69 @@ class TypeAffinityTest(fixtures.TestBase):
         assert t1.dialect_impl(d)._type_affinity is postgresql.UUID
 
 
+class AsGenericTest(fixtures.TestBase):
+    @testing.combinations(
+        (String(), String()),
+        (VARCHAR(length=100), String(length=100)),
+        (NVARCHAR(length=100), Unicode(length=100)),
+        (DATE(), Date()),
+        (pg.JSON(), sa.JSON()),
+        (pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
+        (Enum("a", "b", "c"), Enum("a", "b", "c")),
+        (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")),
+        (mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")),
+        (pg.INTERVAL(precision=5), Interval(native=True, second_precision=5)),
+        (
+            oracle.INTERVAL(second_precision=5, day_precision=5),
+            Interval(native=True, day_precision=5, second_precision=5),
+        ),
+    )
+    def test_as_generic(self, t1, t2):
+        assert repr(t1.as_generic(allow_nulltype=False)) == repr(t2)
+
+    @testing.combinations(
+        *[
+            (t,)
+            for t in _all_types(omit_special_types=True)
+            if not util.method_is_overridden(t, TypeEngine.as_generic)
+        ]
+    )
+    def test_as_generic_all_types_heuristic(self, type_):
+        if issubclass(type_, ARRAY):
+            t1 = type_(String)
+        else:
+            t1 = type_()
+
+        try:
+            gentype = t1.as_generic()
+        except NotImplementedError:
+            pass
+        else:
+            assert isinstance(t1, gentype.__class__)
+            assert isinstance(gentype, TypeEngine)
+
+        gentype = t1.as_generic(allow_nulltype=True)
+        if not isinstance(gentype, types.NULLTYPE.__class__):
+            assert isinstance(t1, gentype.__class__)
+            assert isinstance(gentype, TypeEngine)
+
+    @testing.combinations(
+        *[
+            (t,)
+            for t in _all_types(omit_special_types=True)
+            if util.method_is_overridden(t, TypeEngine.as_generic)
+        ]
+    )
+    def test_as_generic_all_types_custom(self, type_):
+        if issubclass(type_, ARRAY):
+            t1 = type_(String)
+        else:
+            t1 = type_()
+
+        gentype = t1.as_generic(allow_nulltype=False)
+        assert isinstance(gentype, TypeEngine)
+
+
 class PickleTypesTest(fixtures.TestBase):
     @testing.combinations(
         ("Boo", Boolean()),