]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add public protocol for mapped class
authorFederico Caselli <cfederico87@gmail.com>
Mon, 23 Jan 2023 21:51:51 +0000 (22:51 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 25 Jan 2023 20:04:56 +0000 (21:04 +0100)
Fixes: #8624
Change-Id: Ia7a66ae9ba534ed7152f95dfd0f7d05b9d00165a

doc/build/orm/mapping_api.rst
doc/build/orm/mapping_styles.rst
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/decl_base.py
test/ext/mypy/plain_files/declared_attr_one.py
test/orm/declarative/test_basic.py

index 8eebe7c775adefa255cee7dd510ec5a2e8ee7510..1a33f956689b2e4bd7a547eaafc5eccf124e558a 100644 (file)
@@ -141,3 +141,6 @@ Class Mapping API
 
 .. autoclass:: MappedAsDataclass
     :members:
+
+.. autoclass:: MappedClassProtocol
+    :no-members:
index b2639939344eef41d3d5a40a08c2fb6ee9a1612d..b4c21a353de0d31ab6042dcb26ad93ef81e45260 100644 (file)
@@ -36,6 +36,8 @@ the class itself has been :term:`instrumented` to include behaviors linked to
 relational operations both at the level of the class as well as on instances of
 that class. As the process is basically the same in all cases, classes mapped
 from different styles are always fully interoperable with each other.
+The protocol :class:`_orm.MappedClassProtocol` can be used to indicate a mapped
+class when using type checkers such as mypy.
 
 The original mapping API is commonly referred to as "classical" style,
 whereas the more automated style of mapping is known as "declarative" style.
index 6980db2e244e08d8c04a0e26b078398bdceb14e9..d54e1ccb9c7183e5a1db1fa9ca05b16133337535 100644 (file)
@@ -65,6 +65,7 @@ from .decl_api import has_inherited_table as has_inherited_table
 from .decl_api import MappedAsDataclass as MappedAsDataclass
 from .decl_api import registry as registry
 from .decl_api import synonym_for as synonym_for
+from .decl_base import MappedClassProtocol as MappedClassProtocol
 from .descriptor_props import Composite as Composite
 from .descriptor_props import CompositeProperty as CompositeProperty
 from .descriptor_props import Synonym as Synonym
index a379af2ddd67e992017cc8053a7b5fd063308627..9e8b02359735ff0f9524c2cdfcc6ee825d6e3ef4 100644 (file)
@@ -49,7 +49,6 @@ from .interfaces import _IntrospectsAnnotations
 from .interfaces import _MappedAttribute
 from .interfaces import _MapsColumns
 from .interfaces import MapperProperty
-from .mapper import Mapper as mapper
 from .mapper import Mapper
 from .properties import ColumnProperty
 from .properties import MappedColumn
@@ -84,25 +83,38 @@ if TYPE_CHECKING:
 _T = TypeVar("_T", bound=Any)
 
 _MapperKwArgs = Mapping[str, Any]
-
 _TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]]
 
 
-class _DeclMappedClassProtocol(Protocol[_O]):
-    metadata: MetaData
+class MappedClassProtocol(Protocol[_O]):
+    """A protocol representing a SQLAlchemy mapped class.
+
+    The protocol is generic on the type of class, use
+    ``MappedClassProtocol[Any]`` to allow any mapped class.
+    """
+
+    __name__: str
     __mapper__: Mapper[_O]
-    __table__: Table
+    __table__: FromClause
+
+    def __call__(self, **kw: Any) -> _O:
+        ...
+
+
+class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol):
+    "Internal more detailed version of ``MappedClassProtocol``."
+    metadata: MetaData
     __tablename__: str
-    __mapper_args__: Mapping[str, Any]
+    __mapper_args__: _MapperKwArgs
     __table_args__: Optional[_TableArgsType]
 
     _sa_apply_dc_transforms: Optional[_DataclassArguments]
 
     def __declare_first__(self) -> None:
-        pass
+        ...
 
     def __declare_last__(self) -> None:
-        pass
+        ...
 
 
 class _DataclassArguments(TypedDict):
@@ -241,7 +253,7 @@ def _mapper(
     mapper_kw: _MapperKwArgs,
 ) -> Mapper[_O]:
     _ImperativeMapperConfig(registry, cls, table, mapper_kw)
-    return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__
+    return cast("MappedClassProtocol[_O]", cls).__mapper__
 
 
 @util.preload_module("sqlalchemy.orm.decl_api")
@@ -297,7 +309,7 @@ class _MapperConfig:
         manager = attributes.opt_manager_of_class(cls)
         if manager and manager.class_ is cls_:
             raise exc.InvalidRequestError(
-                "Class %r already has been " "instrumented declaratively" % cls
+                f"Class {cls!r} already has been instrumented declaratively"
             )
 
         if cls_.__dict__.get("__abstract__", False):
@@ -382,7 +394,7 @@ class _ImperativeMapperConfig(_MapperConfig):
             self._early_mapping(mapper_kw)
 
     def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]:
-        mapper_cls = mapper
+        mapper_cls = Mapper
 
         return self.set_cls_attribute(
             "__mapper__",
@@ -413,7 +425,7 @@ class _ImperativeMapperConfig(_MapperConfig):
                         % (cls, inherits_search)
                     )
                 inherits = inherits_search[0]
-        elif isinstance(inherits, mapper):
+        elif isinstance(inherits, Mapper):
             inherits = inherits.class_
 
         self.inherits = inherits
@@ -567,7 +579,7 @@ class _ClassScanMapperConfig(_MapperConfig):
     def _setup_declared_events(self) -> None:
         if _get_immediate_cls_attr(self.cls, "__declare_last__"):
 
-            @event.listens_for(mapper, "after_configured")
+            @event.listens_for(Mapper, "after_configured")
             def after_configured() -> None:
                 cast(
                     "_DeclMappedClassProtocol[Any]", self.cls
@@ -575,7 +587,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         if _get_immediate_cls_attr(self.cls, "__declare_first__"):
 
-            @event.listens_for(mapper, "before_configured")
+            @event.listens_for(Mapper, "before_configured")
             def before_configured() -> None:
                 cast(
                     "_DeclMappedClassProtocol[Any]", self.cls
@@ -1507,7 +1519,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
     def _setup_table(self, table: Optional[FromClause] = None) -> None:
         cls = self.cls
-        cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls)
+        cls_as_Decl = cast("MappedClassProtocol[Any]", cls)
 
         tablename = self.tablename
         table_args = self.table_args
@@ -1570,8 +1582,9 @@ class _ClassScanMapperConfig(_MapperConfig):
         self.local_table = table
 
     def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData:
-        if hasattr(self.cls, "metadata"):
-            return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata
+        meta: Optional[MetaData] = getattr(self.cls, "metadata", None)
+        if meta is not None:
+            return meta
         else:
             return manager.registry.metadata
 
@@ -1599,7 +1612,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                         % (cls, inherits_search)
                     )
                 inherits = inherits_search[0]
-        elif isinstance(inherits, mapper):
+        elif isinstance(inherits, Mapper):
             inherits = inherits.class_
 
         self.inherits = inherits
@@ -1701,7 +1714,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         if "inherits" in mapper_args:
             inherits_arg = mapper_args["inherits"]
-            if isinstance(inherits_arg, mapper):
+            if isinstance(inherits_arg, Mapper):
                 inherits_arg = inherits_arg.class_
 
             if inherits_arg is not self.inherits:
@@ -1762,7 +1775,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                 ),
             )
         else:
-            mapper_cls = mapper
+            mapper_cls = Mapper
 
         return self.set_cls_attribute(
             "__mapper__",
@@ -1873,18 +1886,29 @@ def _add_attribute(
     """
 
     if "__mapper__" in cls.__dict__:
-        mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls)
+        mapped_cls = cast("MappedClassProtocol[Any]", cls)
+
+        def _table_or_raise(mc: MappedClassProtocol[Any]) -> Table:
+            if isinstance(mc.__table__, Table):
+                return mc.__table__
+            raise exc.InvalidRequestError(
+                f"Cannot add a new attribute to mapped class {mc.__name__!r} "
+                "because it's not mapped against a table."
+            )
+
         if isinstance(value, Column):
             _undefer_column_name(key, value)
-            # TODO: raise for this is not a Table
-            mapped_cls.__table__.append_column(value, replace_existing=True)
+            _table_or_raise(mapped_cls).append_column(
+                value, replace_existing=True
+            )
             mapped_cls.__mapper__.add_property(key, value)
         elif isinstance(value, _MapsColumns):
             mp = value.mapper_property_to_assign
             for col in value.columns_to_assign:
                 _undefer_column_name(key, col)
-                # TODO: raise for this is not a Table
-                mapped_cls.__table__.append_column(col, replace_existing=True)
+                _table_or_raise(mapped_cls).append_column(
+                    col, replace_existing=True
+                )
                 if not mp:
                     mapped_cls.__mapper__.add_property(key, col)
             if mp:
@@ -1904,12 +1928,11 @@ def _add_attribute(
 
 
 def _del_attribute(cls: Type[Any], key: str) -> None:
-
     if (
         "__mapper__" in cls.__dict__
         and key in cls.__dict__
         and not cast(
-            "_DeclMappedClassProtocol[Any]", cls
+            "MappedClassProtocol[Any]", cls
         ).__mapper__._dispose_called
     ):
         value = cls.__dict__[key]
@@ -1922,7 +1945,7 @@ def _del_attribute(cls: Type[Any], key: str) -> None:
         else:
             type.__delattr__(cls, key)
             cast(
-                "_DeclMappedClassProtocol[Any]", cls
+                "MappedClassProtocol[Any]", cls
             ).__mapper__._expire_memoizations()
     else:
         type.__delattr__(cls, key)
index a6d96f39ee53bfb8b03db080b040cf6d3fb66f31..d4f3c826e6ab69435f51be16c541918246ed2e25 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedClassProtocol
 from sqlalchemy.sql.schema import PrimaryKeyConstraint
 
 
@@ -70,6 +71,24 @@ class Manager(Employee):
         )
 
 
+def do_something_with_mapped_class(
+    cls_: MappedClassProtocol[Employee],
+) -> None:
+
+    # EXPECTED_TYPE: Select[Any]
+    reveal_type(cls_.__table__.select())
+
+    # EXPECTED_TYPE: Mapper[Employee]
+    reveal_type(cls_.__mapper__)
+
+    # EXPECTED_TYPE: Employee
+    reveal_type(cls_())
+
+
+do_something_with_mapped_class(Manager)
+do_something_with_mapped_class(Engineer)
+
+
 if typing.TYPE_CHECKING:
 
     # EXPECTED_TYPE: InstrumentedAttribute[datetime]
index 83d103864f849ef25445cff18f31fd170a67b0cf..28fdc97f23fce33cba283c0a59ad6a877f0305b3 100644 (file)
@@ -611,6 +611,24 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
                     sa.Column("id", Integer, primary_key=True),
                 )
 
+    def test_cannot_add_to_selectable(self):
+        class Base(DeclarativeBase):
+            pass
+
+        class Foo(Base):
+            __table__ = (
+                select(sa.Column("x", sa.Integer, primary_key=True))
+                .select_from(sa.table("foo"))
+                .subquery("foo")
+            )
+
+        with assertions.expect_raises_message(
+            exc.InvalidRequestError,
+            "Cannot add a new attribute to mapped class 'Foo' "
+            "because it's not mapped against a table",
+        ):
+            Foo.y = mapped_column(sa.Text)
+
 
 @testing.combinations(
     ("declarative_base_nometa_superclass",),