From 1526cf68af500141480cc51ec4de18c705fe0b0a Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Mon, 23 Jan 2023 22:51:51 +0100 Subject: [PATCH] Add public protocol for mapped class Fixes: #8624 Change-Id: Ia7a66ae9ba534ed7152f95dfd0f7d05b9d00165a --- doc/build/orm/mapping_api.rst | 3 + doc/build/orm/mapping_styles.rst | 2 + lib/sqlalchemy/orm/__init__.py | 1 + lib/sqlalchemy/orm/decl_base.py | 79 ++++++++++++------- .../ext/mypy/plain_files/declared_attr_one.py | 19 +++++ test/orm/declarative/test_basic.py | 18 +++++ 6 files changed, 94 insertions(+), 28 deletions(-) diff --git a/doc/build/orm/mapping_api.rst b/doc/build/orm/mapping_api.rst index 8eebe7c775..1a33f95668 100644 --- a/doc/build/orm/mapping_api.rst +++ b/doc/build/orm/mapping_api.rst @@ -141,3 +141,6 @@ Class Mapping API .. autoclass:: MappedAsDataclass :members: + +.. autoclass:: MappedClassProtocol + :no-members: diff --git a/doc/build/orm/mapping_styles.rst b/doc/build/orm/mapping_styles.rst index b263993934..b4c21a353d 100644 --- a/doc/build/orm/mapping_styles.rst +++ b/doc/build/orm/mapping_styles.rst @@ -36,6 +36,8 @@ the class itself has been :term:`instrumented` to include behaviors linked to relational operations both at the level of the class as well as on instances of that class. As the process is basically the same in all cases, classes mapped from different styles are always fully interoperable with each other. +The protocol :class:`_orm.MappedClassProtocol` can be used to indicate a mapped +class when using type checkers such as mypy. The original mapping API is commonly referred to as "classical" style, whereas the more automated style of mapping is known as "declarative" style. diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 6980db2e24..d54e1ccb9c 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -65,6 +65,7 @@ from .decl_api import has_inherited_table as has_inherited_table from .decl_api import MappedAsDataclass as MappedAsDataclass from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for +from .decl_base import MappedClassProtocol as MappedClassProtocol from .descriptor_props import Composite as Composite from .descriptor_props import CompositeProperty as CompositeProperty from .descriptor_props import Synonym as Synonym diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a379af2ddd..9e8b023597 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -49,7 +49,6 @@ from .interfaces import _IntrospectsAnnotations from .interfaces import _MappedAttribute from .interfaces import _MapsColumns from .interfaces import MapperProperty -from .mapper import Mapper as mapper from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn @@ -84,25 +83,38 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _MapperKwArgs = Mapping[str, Any] - _TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]] -class _DeclMappedClassProtocol(Protocol[_O]): - metadata: MetaData +class MappedClassProtocol(Protocol[_O]): + """A protocol representing a SQLAlchemy mapped class. + + The protocol is generic on the type of class, use + ``MappedClassProtocol[Any]`` to allow any mapped class. + """ + + __name__: str __mapper__: Mapper[_O] - __table__: Table + __table__: FromClause + + def __call__(self, **kw: Any) -> _O: + ... + + +class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): + "Internal more detailed version of ``MappedClassProtocol``." + metadata: MetaData __tablename__: str - __mapper_args__: Mapping[str, Any] + __mapper_args__: _MapperKwArgs __table_args__: Optional[_TableArgsType] _sa_apply_dc_transforms: Optional[_DataclassArguments] def __declare_first__(self) -> None: - pass + ... def __declare_last__(self) -> None: - pass + ... class _DataclassArguments(TypedDict): @@ -241,7 +253,7 @@ def _mapper( mapper_kw: _MapperKwArgs, ) -> Mapper[_O]: _ImperativeMapperConfig(registry, cls, table, mapper_kw) - return cast("_DeclMappedClassProtocol[_O]", cls).__mapper__ + return cast("MappedClassProtocol[_O]", cls).__mapper__ @util.preload_module("sqlalchemy.orm.decl_api") @@ -297,7 +309,7 @@ class _MapperConfig: manager = attributes.opt_manager_of_class(cls) if manager and manager.class_ is cls_: raise exc.InvalidRequestError( - "Class %r already has been " "instrumented declaratively" % cls + f"Class {cls!r} already has been instrumented declaratively" ) if cls_.__dict__.get("__abstract__", False): @@ -382,7 +394,7 @@ class _ImperativeMapperConfig(_MapperConfig): self._early_mapping(mapper_kw) def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: - mapper_cls = mapper + mapper_cls = Mapper return self.set_cls_attribute( "__mapper__", @@ -413,7 +425,7 @@ class _ImperativeMapperConfig(_MapperConfig): % (cls, inherits_search) ) inherits = inherits_search[0] - elif isinstance(inherits, mapper): + elif isinstance(inherits, Mapper): inherits = inherits.class_ self.inherits = inherits @@ -567,7 +579,7 @@ class _ClassScanMapperConfig(_MapperConfig): def _setup_declared_events(self) -> None: if _get_immediate_cls_attr(self.cls, "__declare_last__"): - @event.listens_for(mapper, "after_configured") + @event.listens_for(Mapper, "after_configured") def after_configured() -> None: cast( "_DeclMappedClassProtocol[Any]", self.cls @@ -575,7 +587,7 @@ class _ClassScanMapperConfig(_MapperConfig): if _get_immediate_cls_attr(self.cls, "__declare_first__"): - @event.listens_for(mapper, "before_configured") + @event.listens_for(Mapper, "before_configured") def before_configured() -> None: cast( "_DeclMappedClassProtocol[Any]", self.cls @@ -1507,7 +1519,7 @@ class _ClassScanMapperConfig(_MapperConfig): def _setup_table(self, table: Optional[FromClause] = None) -> None: cls = self.cls - cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + cls_as_Decl = cast("MappedClassProtocol[Any]", cls) tablename = self.tablename table_args = self.table_args @@ -1570,8 +1582,9 @@ class _ClassScanMapperConfig(_MapperConfig): self.local_table = table def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData: - if hasattr(self.cls, "metadata"): - return cast("_DeclMappedClassProtocol[Any]", self.cls).metadata + meta: Optional[MetaData] = getattr(self.cls, "metadata", None) + if meta is not None: + return meta else: return manager.registry.metadata @@ -1599,7 +1612,7 @@ class _ClassScanMapperConfig(_MapperConfig): % (cls, inherits_search) ) inherits = inherits_search[0] - elif isinstance(inherits, mapper): + elif isinstance(inherits, Mapper): inherits = inherits.class_ self.inherits = inherits @@ -1701,7 +1714,7 @@ class _ClassScanMapperConfig(_MapperConfig): if "inherits" in mapper_args: inherits_arg = mapper_args["inherits"] - if isinstance(inherits_arg, mapper): + if isinstance(inherits_arg, Mapper): inherits_arg = inherits_arg.class_ if inherits_arg is not self.inherits: @@ -1762,7 +1775,7 @@ class _ClassScanMapperConfig(_MapperConfig): ), ) else: - mapper_cls = mapper + mapper_cls = Mapper return self.set_cls_attribute( "__mapper__", @@ -1873,18 +1886,29 @@ def _add_attribute( """ if "__mapper__" in cls.__dict__: - mapped_cls = cast("_DeclMappedClassProtocol[Any]", cls) + mapped_cls = cast("MappedClassProtocol[Any]", cls) + + def _table_or_raise(mc: MappedClassProtocol[Any]) -> Table: + if isinstance(mc.__table__, Table): + return mc.__table__ + raise exc.InvalidRequestError( + f"Cannot add a new attribute to mapped class {mc.__name__!r} " + "because it's not mapped against a table." + ) + if isinstance(value, Column): _undefer_column_name(key, value) - # TODO: raise for this is not a Table - mapped_cls.__table__.append_column(value, replace_existing=True) + _table_or_raise(mapped_cls).append_column( + value, replace_existing=True + ) mapped_cls.__mapper__.add_property(key, value) elif isinstance(value, _MapsColumns): mp = value.mapper_property_to_assign for col in value.columns_to_assign: _undefer_column_name(key, col) - # TODO: raise for this is not a Table - mapped_cls.__table__.append_column(col, replace_existing=True) + _table_or_raise(mapped_cls).append_column( + col, replace_existing=True + ) if not mp: mapped_cls.__mapper__.add_property(key, col) if mp: @@ -1904,12 +1928,11 @@ def _add_attribute( def _del_attribute(cls: Type[Any], key: str) -> None: - if ( "__mapper__" in cls.__dict__ and key in cls.__dict__ and not cast( - "_DeclMappedClassProtocol[Any]", cls + "MappedClassProtocol[Any]", cls ).__mapper__._dispose_called ): value = cls.__dict__[key] @@ -1922,7 +1945,7 @@ def _del_attribute(cls: Type[Any], key: str) -> None: else: type.__delattr__(cls, key) cast( - "_DeclMappedClassProtocol[Any]", cls + "MappedClassProtocol[Any]", cls ).__mapper__._expire_memoizations() else: type.__delattr__(cls, key) diff --git a/test/ext/mypy/plain_files/declared_attr_one.py b/test/ext/mypy/plain_files/declared_attr_one.py index a6d96f39ee..d4f3c826e6 100644 --- a/test/ext/mypy/plain_files/declared_attr_one.py +++ b/test/ext/mypy/plain_files/declared_attr_one.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declared_attr from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedClassProtocol from sqlalchemy.sql.schema import PrimaryKeyConstraint @@ -70,6 +71,24 @@ class Manager(Employee): ) +def do_something_with_mapped_class( + cls_: MappedClassProtocol[Employee], +) -> None: + + # EXPECTED_TYPE: Select[Any] + reveal_type(cls_.__table__.select()) + + # EXPECTED_TYPE: Mapper[Employee] + reveal_type(cls_.__mapper__) + + # EXPECTED_TYPE: Employee + reveal_type(cls_()) + + +do_something_with_mapped_class(Manager) +do_something_with_mapped_class(Engineer) + + if typing.TYPE_CHECKING: # EXPECTED_TYPE: InstrumentedAttribute[datetime] diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 83d103864f..28fdc97f23 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -611,6 +611,24 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): sa.Column("id", Integer, primary_key=True), ) + def test_cannot_add_to_selectable(self): + class Base(DeclarativeBase): + pass + + class Foo(Base): + __table__ = ( + select(sa.Column("x", sa.Integer, primary_key=True)) + .select_from(sa.table("foo")) + .subquery("foo") + ) + + with assertions.expect_raises_message( + exc.InvalidRequestError, + "Cannot add a new attribute to mapped class 'Foo' " + "because it's not mapped against a table", + ): + Foo.y = mapped_column(sa.Text) + @testing.combinations( ("declarative_base_nometa_superclass",), -- 2.47.2