From f4aaba7b96d2344b60ad88669328e4fe5280f1d3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 8 Jul 2022 14:36:25 -0400 Subject: [PATCH] add some typing tests for declared_attr, mixins, decl base to make typing easier, looking at using getattr on table.c rather than getitem for now Change-Id: I7946885071d0b0ddfc06be009f033495f9906de5 --- .../ext/mypy/plain_files/declared_attr_one.py | 69 +++++++++++++++++++ .../ext/mypy/plain_files/declared_attr_two.py | 38 ++++++++++ test/orm/declarative/test_mixin.py | 62 ++++++++++------- 3 files changed, 146 insertions(+), 23 deletions(-) create mode 100644 test/ext/mypy/plain_files/declared_attr_one.py create mode 100644 test/ext/mypy/plain_files/declared_attr_two.py diff --git a/test/ext/mypy/plain_files/declared_attr_one.py b/test/ext/mypy/plain_files/declared_attr_one.py new file mode 100644 index 0000000000..969c86c3cd --- /dev/null +++ b/test/ext/mypy/plain_files/declared_attr_one.py @@ -0,0 +1,69 @@ +from datetime import datetime +import typing + +from sqlalchemy import DateTime +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Employee(Base): + __tablename__ = "employee" + id = mapped_column(Integer, primary_key=True) + name = mapped_column(String(50)) + type = mapped_column(String(20)) + + __mapper_args__ = { + "polymorphic_on": type, + "polymorphic_identity": "employee", + } + + +class Engineer(Employee): + __mapper_args__ = { + "polymorphic_identity": "engineer", + } + + @declared_attr + def start_date(cls) -> Mapped[datetime]: + "Start date column, if not present already." + + assert Employee.__table__ is not None + return getattr( + Employee.__table__.c, + "start date", + mapped_column("start date", DateTime), + ) + + +class Manager(Employee): + __mapper_args__ = { + "polymorphic_identity": "manager", + } + + @declared_attr + def start_date(cls) -> Mapped[datetime]: + "Start date column, if not present already." + + assert Employee.__table__ is not None + return getattr( + Employee.__table__.c, + "start date", + mapped_column("start date", DateTime), + ) + + +if typing.TYPE_CHECKING: + + # EXPECTED_TYPE: InstrumentedAttribute[datetime] + reveal_type(Engineer.start_date) + + # EXPECTED_TYPE: InstrumentedAttribute[datetime] + reveal_type(Manager.start_date) diff --git a/test/ext/mypy/plain_files/declared_attr_two.py b/test/ext/mypy/plain_files/declared_attr_two.py new file mode 100644 index 0000000000..86e679e382 --- /dev/null +++ b/test/ext/mypy/plain_files/declared_attr_two.py @@ -0,0 +1,38 @@ +import typing + +from sqlalchemy import Integer +from sqlalchemy import Text +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class HasRelatedDataMixin: + @declared_attr + def related_data(cls) -> Mapped[str]: + return mapped_column(Text(), deferred=True) + + +class User(HasRelatedDataMixin, Base): + __tablename__ = "user" + id = mapped_column(Integer, primary_key=True) + + +u1 = User() + + +u1.related_data + + +if typing.TYPE_CHECKING: + + # EXPECTED_TYPE: str + reveal_type(u1.related_data) + + # EXPECTED_TYPE: InstrumentedAttribute[str] + reveal_type(User.related_data) diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index e6f669c3cd..72e14ceebd 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import column_property from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_mixin +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import events as orm_events @@ -52,7 +53,9 @@ class DeclarativeTestBase( global Base, mapper_registry mapper_registry = registry(metadata=MetaData()) - Base = mapper_registry.generate_base() + + class Base(DeclarativeBase): + registry = mapper_registry def teardown_test(self): close_all_sessions() @@ -62,12 +65,29 @@ class DeclarativeTestBase( class DeclarativeMixinTest(DeclarativeTestBase): - def test_init_subclass_works(self, registry): - class Base: - def __init_subclass__(cls): - cls.id = Column(Integer, primary_key=True) + @testing.combinations("generate_base", "subclass", argnames="base_type") + def test_init_subclass_works(self, registry, base_type): + reg = registry + if base_type == "generate_base": + + class Base: + def __init_subclass__(cls): + cls.id = Column(Integer, primary_key=True) + + Base = registry.generate_base(cls=Base) + elif base_type == "subclass": + + class Base(DeclarativeBase): + registry = reg + + def __init_subclass__(cls): + cls.id = Column(Integer, primary_key=True) + # hmmm what do we think of this. if DeclarativeBase + # used a full metaclass approach we wouldn't need this. + super().__init_subclass__() - Base = registry.generate_base(cls=Base) + else: + assert False class Foo(Base): __tablename__ = "foo" @@ -127,10 +147,6 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(obj.foo(), "bar1") def test_declarative_mixin_decorator(self): - - # note we are also making sure an "old style class" in Python 2, - # as we are now illustrating in all the docs for mixins, doesn't cause - # a problem.... @declarative_mixin class MyMixin: @@ -141,10 +157,8 @@ class DeclarativeMixinTest(DeclarativeTestBase): def foo(self): return "bar" + str(self.id) - # ...as long as the mapped class itself is "new style", which will - # normally be the case for users using declarative_base @mapper_registry.mapped - class MyModel(MyMixin, object): + class MyModel(MyMixin): __tablename__ = "test" name = Column(String(100), nullable=False, index=True) @@ -645,7 +659,8 @@ class DeclarativeMixinTest(DeclarativeTestBase): lambda: testing.against("oracle"), "Test has an empty insert in it at the moment", ) - def test_columns_single_inheritance_conflict_resolution(self): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_columns_single_inheritance_conflict_resolution(self, _column): """Test that a declared_attr can return the existing column and it will be ignored. this allows conditional columns to be added. @@ -655,13 +670,13 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Person(Base): __tablename__ = "person" - id = Column(Integer, primary_key=True) + id = _column(Integer, primary_key=True) class Mixin: @declared_attr def target_id(cls): return cls.__table__.c.get( - "target_id", Column(Integer, ForeignKey("other.id")) + "target_id", _column(Integer, ForeignKey("other.id")) ) @declared_attr @@ -678,7 +693,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): class Other(Base): __tablename__ = "other" - id = Column(Integer, primary_key=True) + id = _column(Integer, primary_key=True) is_( Engineer.target_id.property.columns[0], @@ -697,7 +712,8 @@ class DeclarativeMixinTest(DeclarativeTestBase): session.commit() eq_(session.query(Engineer).first().target, o1) - def test_columns_joined_table_inheritance(self): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_columns_joined_table_inheritance(self, _column): """Test a column on a mixin with an alternate attribute name, mapped to a superclass and joined-table inheritance subclass. Both tables get the column, in the case of the subclass the two @@ -706,18 +722,18 @@ class DeclarativeMixinTest(DeclarativeTestBase): """ class MyMixin: - foo = Column("foo", Integer) - bar = Column("bar_newname", Integer) + foo = _column("foo", Integer) + bar = _column("bar_newname", Integer) class General(Base, MyMixin): __tablename__ = "test" - id = Column(Integer, primary_key=True) - type_ = Column(String(50)) + id = _column(Integer, primary_key=True) + type_ = _column(String(50)) __mapper__args = {"polymorphic_on": type_} class Specific(General): __tablename__ = "sub" - id = Column(Integer, ForeignKey("test.id"), primary_key=True) + id = _column(Integer, ForeignKey("test.id"), primary_key=True) __mapper_args__ = {"polymorphic_identity": "specific"} assert General.bar.prop.columns[0] is General.__table__.c.bar_newname -- 2.47.2