]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
allow TypeEngine classes to export their repr() parameters
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Feb 2026 21:04:48 +0000 (16:04 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Feb 2026 16:10:52 +0000 (11:10 -0500)
Added a new :class:`.GenericRepr` class and :meth:`.TypeEngine.repr_struct`
method to provide better control over type representation. The
:class:`.TypeDecorator` class now properly displays its decorated type's
parameters in its ``__repr__()``, allowing introspection tools and
libraries like Alembic to better understand the structure of decorated
types, in particular for complex "schema" types such as :class:`.Enum` and
:class:`.Boolean`. Type classes can override
:meth:`.TypeEngine.repr_struct` to customize their representation
structure, and the returned :class:`.GenericRepr` object allows for
modifications such as changing the displayed class name.

Fixes: #13140
Change-Id: Ie41d249cfea56686b16c895b74ae03721207170b

doc/build/changelog/unreleased_21/13140.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/enumerated.py
lib/sqlalchemy/dialects/mysql/types.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_21/13140.rst b/doc/build/changelog/unreleased_21/13140.rst
new file mode 100644 (file)
index 0000000..62a485d
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 13140
+
+    Improved the ability for :class:`.TypeDecorator` to produce a correct
+    ``repr()`` for "schema" types such as :class:`.Enum` and :class:`.Boolean`.
+    This is mostly to support the Alembic autogenerate use case so that custom
+    types render with relevant arguments present.  Improved the architecture
+    used by :class:`.TypeEngine` to produce ``repr()`` strings to be more
+    modular for compound types like :class:`.TypeDecorator`.
index 599cd781064d0a81a7dcb9b43fe172231b82e48d..0caffc1edfd90be71462f44122efdd2f14c4be2a 100644 (file)
@@ -104,8 +104,8 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
         else:
             return super()._object_value_for_elem(elem)
 
-    def __repr__(self) -> str:
-        return util.generic_repr(
+    def repr_struct(self) -> util.GenericRepr:
+        return util.GenericRepr(
             self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
         )
 
@@ -267,8 +267,8 @@ class SET(_StringType):
         kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
         return util.constructor_copy(self, cls, *self.values, **kw)
 
-    def __repr__(self) -> str:
-        return util.generic_repr(
+    def repr_struct(self) -> util.GenericRepr:
+        return util.GenericRepr(
             self,
             to_inspect=[SET, _StringType],
             additional_kw=[
index d841485547b9f3d5ba17e15e4a3064898a69902a..76376081c8d5ec19d94cc2d8304c8dff2f3681af 100644 (file)
@@ -46,8 +46,8 @@ class _NumericType(
     _NumericCommonType, sqltypes.Numeric[Union[decimal.Decimal, float]]
 ):
 
-    def __repr__(self) -> str:
-        return util.generic_repr(
+    def repr_struct(self) -> util.GenericRepr:
+        return util.GenericRepr(
             self,
             to_inspect=[_NumericType, _NumericCommonType, sqltypes.Numeric],
         )
@@ -75,8 +75,8 @@ class _FloatType(
         super().__init__(precision=precision, asdecimal=asdecimal, **kw)
         self.scale = scale
 
-    def __repr__(self) -> str:
-        return util.generic_repr(
+    def repr_struct(self) -> util.GenericRepr:
+        return util.GenericRepr(
             self, to_inspect=[_FloatType, _NumericCommonType, sqltypes.Float]
         )
 
@@ -86,8 +86,8 @@ class _IntegerType(_NumericCommonType, sqltypes.Integer):
         self.display_width = display_width
         super().__init__(**kw)
 
-    def __repr__(self) -> str:
-        return util.generic_repr(
+    def repr_struct(self) -> util.GenericRepr:
+        return util.GenericRepr(
             self,
             to_inspect=[_IntegerType, _NumericCommonType, sqltypes.Integer],
         )
@@ -117,8 +117,8 @@ class _StringType(sqltypes.String):
         self.national = national
         super().__init__(**kw)
 
-    def __repr__(self) -> str:
-        return util.generic_repr(
+    def repr_struct(self) -> util.GenericRepr:
+        return util.GenericRepr(
             self, to_inspect=[_StringType, sqltypes.String]
         )
 
index 6308f3014d0ee1cd0d244d836e5f77de3d76fce4..98e5783a3d0276a065c4e0181706b4c892df2a19 100644 (file)
@@ -1841,8 +1841,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
                 )
             ) from err
 
-    def __repr__(self):
-        return util.generic_repr(
+    def repr_struct(self):
+        return util.GenericRepr(
             self,
             additional_kw=[
                 ("native_enum", True),
index f1b6a28e8d90ab9d801bf7590126b156d99ab815..cd04ca0e11a31cb3f811314d2d8a55cc6b2a84b6 100644 (file)
@@ -1158,8 +1158,19 @@ class TypeEngine(Visitable, Generic[_T]):
     def __str__(self) -> str:
         return str(self.compile())
 
+    def repr_struct(self) -> util.GenericRepr:
+        """Return a :class:`.GenericRepr` object representing this type.
+
+        This method is used to generate the repr string for the type.
+        Subclasses can override this to customize the repr structure.
+
+        .. versionadded:: 2.1
+
+        """
+        return util.GenericRepr(self)
+
     def __repr__(self) -> str:
-        return util.generic_repr(self)
+        return str(self.repr_struct())
 
 
 class TypeEngineMixin:
@@ -2362,8 +2373,18 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
     def sort_key_function(self) -> Optional[Callable[[Any], Any]]:  # type: ignore # noqa: E501
         return self.impl_instance.sort_key_function
 
-    def __repr__(self) -> str:
-        return util.generic_repr(self, to_inspect=self.impl_instance)
+    def repr_struct(self) -> util.GenericRepr:
+        """Return a :class:`.GenericRepr` object representing this type.
+
+        For TypeDecorator, this returns a repr structure based on the
+        impl instance but with the TypeDecorator's class name.
+
+        .. versionadded:: 2.1
+
+        """
+        return self.impl_instance.repr_struct().set_class_name(
+            self.__class__.__name__
+        )
 
 
 class Variant(TypeDecorator[_T]):
index 5d55a691d5768752cd879f1fef3e32bd461bd6b5..183a6a1f20bc1bb82ee0b21fb9ecfcfa58e26cc9 100644 (file)
@@ -106,6 +106,7 @@ from .langhelpers import format_argspec_init as format_argspec_init
 from .langhelpers import format_argspec_plus as format_argspec_plus
 from .langhelpers import generic_fn_descriptor as generic_fn_descriptor
 from .langhelpers import generic_repr as generic_repr
+from .langhelpers import GenericRepr as GenericRepr
 from .langhelpers import get_callable_argspec as get_callable_argspec
 from .langhelpers import get_cls_kwargs as get_cls_kwargs
 from .langhelpers import get_func_kwargs as get_func_kwargs
index 6767a97121375ad041b2e5b62568e40652aa6627..affa6c4fa0bb8677e72e9a37fcde6f75127d6506 100644 (file)
@@ -767,74 +767,119 @@ def unbound_method_to_callable(func_or_cls):
         return func_or_cls
 
 
-def generic_repr(
-    obj: Any,
-    additional_kw: Sequence[Tuple[str, Any]] = (),
-    to_inspect: Optional[Union[object, List[object]]] = None,
-    omit_kwarg: Sequence[str] = (),
-) -> str:
-    """Produce a __repr__() based on direct association of the __init__()
-    specification vs. same-named attributes present.
+class GenericRepr:
+    """Encapsulates the logic for creating a generic __repr__() string.
 
+    This class allows for the repr structure to be created, then modified
+    (e.g., changing the class name), before being rendered as a string.
+
+    .. versionadded:: 2.1
     """
-    if to_inspect is None:
-        to_inspect = [obj]
-    else:
-        to_inspect = _collections.to_list(to_inspect)
 
-    missing = object()
+    __slots__ = (
+        "_obj",
+        "_additional_kw",
+        "_to_inspect",
+        "_omit_kwarg",
+        "_class_name",
+    )
 
-    pos_args = []
-    kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict()
-    vargs = None
-    for i, insp in enumerate(to_inspect):
-        try:
-            spec = compat.inspect_getfullargspec(insp.__init__)
-        except TypeError:
-            continue
-        else:
-            default_len = len(spec.defaults) if spec.defaults else 0
-            if i == 0:
-                if spec.varargs:
-                    vargs = spec.varargs
-                if default_len:
-                    pos_args.extend(spec.args[1:-default_len])
-                else:
-                    pos_args.extend(spec.args[1:])
-            else:
-                kw_args.update(
-                    [(arg, missing) for arg in spec.args[1:-default_len]]
-                )
+    _obj: Any
+    _additional_kw: Sequence[Tuple[str, Any]]
+    _to_inspect: List[object]
+    _omit_kwarg: Sequence[str]
+    _class_name: Optional[str]
 
-            if default_len:
-                assert spec.defaults
-                kw_args.update(
-                    [
-                        (arg, default)
-                        for arg, default in zip(
-                            spec.args[-default_len:], spec.defaults
-                        )
-                    ]
-                )
-    output: List[str] = []
+    def __init__(
+        self,
+        obj: Any,
+        additional_kw: Sequence[Tuple[str, Any]] = (),
+        to_inspect: Optional[Union[object, List[object]]] = None,
+        omit_kwarg: Sequence[str] = (),
+    ):
+        """Create a GenericRepr object.
+
+        :param obj: The object being repr'd
+        :param additional_kw: Additional keyword arguments to check for in
+            the repr, as a sequence of 2-tuples of (name, default_value)
+        :param to_inspect: One or more objects whose __init__ signature
+            should be inspected. If not provided, defaults to [obj].
+        :param omit_kwarg: Sequence of keyword argument names to omit from
+            the repr output
+        """
+        self._obj = obj
+        self._additional_kw = additional_kw
+        self._to_inspect = (
+            [obj] if to_inspect is None else _collections.to_list(to_inspect)
+        )
+        self._omit_kwarg = omit_kwarg
+        self._class_name = None
 
-    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
+    def set_class_name(self, class_name: str) -> GenericRepr:
+        """Set the class name to be used in the repr.
 
-    if vargs is not None and hasattr(obj, vargs):
-        output.extend([repr(val) for val in getattr(obj, vargs)])
+        By default, the class name is taken from obj.__class__.__name__.
+        This method allows it to be overridden.
 
-    for arg, defval in kw_args.items():
-        if arg in omit_kwarg:
-            continue
-        try:
-            val = getattr(obj, arg, missing)
-            if val is not missing and val != defval:
-                output.append("%s=%r" % (arg, val))
-        except Exception:
-            pass
+        :param class_name: The class name to use
+        :return: self, for method chaining
+        """
+        self._class_name = class_name
+        return self
 
-    if additional_kw:
-        for arg, defval in additional_kw:
+    def __str__(self) -> str:
+        """Produce the __repr__() string based on the configured parameters."""
+        obj = self._obj
+        to_inspect = self._to_inspect
+        additional_kw = self._additional_kw
+        omit_kwarg = self._omit_kwarg
+
+        missing = object()
+
+        pos_args = []
+        kw_args: _collections.OrderedDict[str, Any] = (
+            _collections.OrderedDict()
+        )
+        vargs = None
+        for i, insp in enumerate(to_inspect):
+            try:
+                spec = compat.inspect_getfullargspec(insp.__init__)  # type: ignore[misc]  # noqa: E501
+            except TypeError:
+                continue
+            else:
+                default_len = len(spec.defaults) if spec.defaults else 0
+                if i == 0:
+                    if spec.varargs:
+                        vargs = spec.varargs
+                    if default_len:
+                        pos_args.extend(spec.args[1:-default_len])
+                    else:
+                        pos_args.extend(spec.args[1:])
+                else:
+                    kw_args.update(
+                        [(arg, missing) for arg in spec.args[1:-default_len]]
+                    )
+
+                if default_len:
+                    assert spec.defaults
+                    kw_args.update(
+                        [
+                            (arg, default)
+                            for arg, default in zip(
+                                spec.args[-default_len:], spec.defaults
+                            )
+                        ]
+                    )
+        output: List[str] = []
+
+        output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
+
+        if vargs is not None and hasattr(obj, vargs):
+            output.extend([repr(val) for val in getattr(obj, vargs)])
+
+        for arg, defval in kw_args.items():
+            if arg in omit_kwarg:
+                continue
             try:
                 val = getattr(obj, arg, missing)
                 if val is not missing and val != defval:
@@ -842,7 +887,41 @@ def generic_repr(
             except Exception:
                 pass
 
-    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
+        if additional_kw:
+            for arg, defval in additional_kw:
+                try:
+                    val = getattr(obj, arg, missing)
+                    if val is not missing and val != defval:
+                        output.append("%s=%r" % (arg, val))
+                except Exception:
+                    pass
+
+        class_name = (
+            self._class_name
+            if self._class_name is not None
+            else obj.__class__.__name__
+        )
+        return "%s(%s)" % (class_name, ", ".join(output))
+
+
+def generic_repr(
+    obj: Any,
+    additional_kw: Sequence[Tuple[str, Any]] = (),
+    to_inspect: Optional[Union[object, List[object]]] = None,
+    omit_kwarg: Sequence[str] = (),
+) -> str:
+    """Produce a __repr__() based on direct association of the __init__()
+    specification vs. same-named attributes present.
+
+    """
+    return str(
+        GenericRepr(
+            obj,
+            additional_kw=additional_kw,
+            to_inspect=to_inspect,
+            omit_kwarg=omit_kwarg,
+        )
+    )
 
 
 def class_hierarchy(cls):
index 826a6f2ec4b9919f0445c1ab220f883764244116..17b947f45af1fbc6a69196b834d6ce2a89a58f7c 100644 (file)
@@ -103,6 +103,7 @@ from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.types import UserDefinedType
+from sqlalchemy.util import GenericRepr
 
 
 def _all_dialect_modules():
@@ -4648,3 +4649,102 @@ class ResolveForLiteralTest(fixtures.TestBase):
     )
     def test_resolve(self, value, expected):
         is_(literal(value).type, expected)
+
+
+class ReprTest(fixtures.TestBase):
+    """test suite for TypeEngine repr_struct() and GenericRepr"""
+
+    def test_generic_repr_basic(self):
+        """Test GenericRepr basic functionality."""
+        t = String(50)
+        gr = GenericRepr(t)
+        eq_(str(gr), "String(length=50)")
+
+    def test_generic_repr_set_class_name(self):
+        """Test GenericRepr.set_class_name() method."""
+        t = String(50)
+        gr = GenericRepr(t)
+        gr.set_class_name("CustomString")
+        eq_(str(gr), "CustomString(length=50)")
+
+    def test_type_engine_repr_struct(self):
+        """Test TypeEngine.repr_struct() returns GenericRepr."""
+        t = String(50)
+        gr = t.repr_struct()
+        assert isinstance(gr, GenericRepr)
+        eq_(str(gr), "String(length=50)")
+
+    @testing.combinations(
+        (Integer(), "Integer()"),
+        (String(50), "String(length=50)"),
+        (VARCHAR(100), "VARCHAR(length=100)"),
+        (NUMERIC(10, 2), "NUMERIC(precision=10, scale=2)"),
+        (
+            Enum("a", "b", "c", name="myenum"),
+            "Enum('a', 'b', 'c', name='myenum')",
+        ),
+        (
+            mysql.NUMERIC(10, 2, unsigned=True),
+            "NUMERIC(unsigned=True, precision=10, scale=2)",
+        ),
+        (
+            mysql.VARCHAR(50, charset="utf8"),
+            "VARCHAR(charset='utf8', length=50)",
+        ),
+        (mysql.ENUM("a", "b", "c"), "ENUM('a', 'b', 'c')"),
+        (mysql.SET("a", "b", "c"), "SET('a', 'b', 'c')"),
+        argnames="type_,expected",
+    )
+    def test_type_repr(self, type_, expected):
+        """Test repr for various type objects."""
+        eq_(repr(type_), expected)
+
+    @testing.variation("impl_type", ["enum", "boolean", "string"])
+    @testing.variation("has_name", [True, False])
+    def test_type_decorator_repr(self, impl_type, has_name):
+        """Test TypeDecorator wrapping various SchemaType objects."""
+
+        if impl_type.enum:
+
+            class MyType(TypeDecorator):
+                impl = Enum
+                cache_ok = True
+
+            if has_name:
+                t = MyType("a", "b", "c", name="myenum")
+                eq_(repr(t), "MyType('a', 'b', 'c', name='myenum')")
+            else:
+                t = MyType("x", "y", "z")
+                eq_(repr(t), "MyType('x', 'y', 'z')")
+
+        elif impl_type.boolean:
+
+            class MyType(TypeDecorator):
+                impl = Boolean
+                cache_ok = True
+
+            if has_name:
+                t = MyType(create_constraint=True, name="mybool")
+                eq_(
+                    repr(t),
+                    "MyType(create_constraint=True, name='mybool')",
+                )
+            else:
+                t = MyType()
+                eq_(repr(t), "MyType()")
+
+        elif impl_type.string:
+
+            class MyType(TypeDecorator):
+                impl = String
+                cache_ok = True
+
+            if has_name:
+                # String doesn't have a name parameter, use length
+                t = MyType(100)
+                eq_(repr(t), "MyType(length=100)")
+            else:
+                t = MyType()
+                eq_(repr(t), "MyType()")
+        else:
+            impl_type.fail()