From e30c7ef4e95aea4febbbb51241a03036872d7920 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 30 Nov 2021 17:12:28 +0100 Subject: [PATCH] =?utf8?q?=E2=9C=A8=20Update=20type=20annotations=20and=20?= =?utf8?q?upgrade=20mypy=20(#173)?= MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 6 ++- sqlmodel/engine/create.py | 2 +- sqlmodel/engine/result.py | 8 ++-- sqlmodel/ext/asyncio/session.py | 4 +- sqlmodel/main.py | 73 ++++++++++++++++++------------- sqlmodel/orm/session.py | 6 +-- sqlmodel/sql/base.py | 4 +- sqlmodel/sql/expression.py | 32 +++++++------- sqlmodel/sql/expression.py.jinja2 | 12 ++--- sqlmodel/sql/sqltypes.py | 19 ++++---- 10 files changed, 90 insertions(+), 76 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc567909..a8355cf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true} [tool.poetry.dev-dependencies] pytest = "^6.2.4" -mypy = "^0.812" +mypy = "^0.910" flake8 = "^3.9.2" black = {version = "^21.5-beta.1", python = "^3.7"} mkdocs = "^1.2.1" @@ -98,3 +98,7 @@ warn_return_any = true implicit_reexport = false strict_equality = true # --strict end + +[[tool.mypy.overrides]] +module = "sqlmodel.sql.expression" +warn_unused_ignores = false diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py index 97481259..b2d567b1 100644 --- a/sqlmodel/engine/create.py +++ b/sqlmodel/engine/create.py @@ -136,4 +136,4 @@ def create_engine( if not isinstance(query_cache_size, _DefaultPlaceholder): current_kwargs["query_cache_size"] = query_cache_size current_kwargs.update(kwargs) - return _create_engine(url, **current_kwargs) + return _create_engine(url, **current_kwargs) # type: ignore diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index d5214275..7a254222 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]): return super().__iter__() def __next__(self) -> _T: - return super().__next__() + return super().__next__() # type: ignore def first(self) -> Optional[_T]: return super().first() @@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]): return super().one_or_none() def one(self) -> _T: - return super().one() + return super().one() # type: ignore class Result(_Result, Generic[_T]): @@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]): return super().scalar_one() # type: ignore def scalar_one_or_none(self) -> Optional[_T]: - return super().scalar_one_or_none() # type: ignore + return super().scalar_one_or_none() def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() # type: ignore + return super().scalar() diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 40e5b766..80267b25 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession): self, bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, - **kw, + **kw: Any, ): # All the same code of the original AsyncSession kw["future"] = True @@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession): # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore - return await greenlet_spawn( # type: ignore + return await greenlet_spawn( self.sync_session.exec, statement, params=params, diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b3..84e26c45 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -101,7 +101,7 @@ class RelationshipInfo(Representation): *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> None: @@ -127,32 +127,32 @@ def Field( default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None, - alias: str = None, - title: str = None, - description: str = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, exclude: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any ] = None, include: Union[ AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any ] = None, - const: bool = None, - gt: float = None, - ge: float = None, - lt: float = None, - le: float = None, - multiple_of: float = None, - min_items: int = None, - max_items: int = None, - min_length: int = None, - max_length: int = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, allow_mutation: bool = True, - regex: str = None, + regex: Optional[str] = None, primary_key: bool = False, foreign_key: Optional[Any] = None, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, - sa_column: Union[Column, UndefinedType] = Undefined, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, @@ -195,7 +195,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if getattr(cls.__config__, "table", False): # type: ignore + if getattr(cls.__config__, "table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if getattr(cls.__config__, "table", False): # type: ignore + if getattr(cls.__config__, "table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) # From Pydantic - def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any: + def __new__( + cls, + name: str, + bases: Tuple[Type[Any], ...], + class_dict: Dict[str, Any], + **kwargs: Any, + ) -> Any: relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} original_annotations = resolve_annotations( @@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): ) relationship_to = temp_field.type_ if isinstance(temp_field.type_, ForwardRef): - relationship_to = temp_field.type_.__forward_arg__ # type: ignore + relationship_to = temp_field.type_.__forward_arg__ rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( + rel_value: RelationshipProperty = relationship( # type: ignore relationship_to, *rel_args, **rel_kwargs ) dict_used[rel_name] = rel_value @@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any: return GUID -def get_column_from_field(field: ModelField) -> Column: +def get_column_from_field(field: ModelField) -> Column: # type: ignore sa_column = getattr(field.field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column @@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column: kwargs["default"] = sa_default sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) if sa_column_args is not Undefined: - args.extend(list(cast(Sequence, sa_column_args))) + args.extend(list(cast(Sequence[Any], sa_column_args))) sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) if sa_column_kwargs is not Undefined: - kwargs.update(cast(dict, sa_column_kwargs)) + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) @@ -452,24 +458,27 @@ class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() -def _value_items_is_true(v) -> bool: +def _value_items_is_true(v: Any) -> bool: # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of # the current latest, Pydantic 1.8.2 return v is True or v is ... +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") + + class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __name__: ClassVar[str] metadata: ClassVar[MetaData] class Config: orm_mode = True - def __new__(cls, *args, **kwargs) -> Any: + def __new__(cls, *args: Any, **kwargs: Any) -> Any: new_object = super().__new__(cls) # SQLAlchemy doesn't call __init__ on the base class # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html @@ -520,7 +529,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry super().__setattr__(name, value) @classmethod - def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None): + def from_orm( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: # Duplicated from Pydantic if not cls.__config__.orm_mode: raise ConfigError( @@ -533,7 +544,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # End SQLModel support dict if not getattr(cls.__config__, "table", False): # If not table, normal Pydantic code - m = cls.__new__(cls) + m: _TSQLModel = cls.__new__(cls) else: # If table, create the new instance normally to make SQLAlchemy create # the _sa_instance_state attribute @@ -554,7 +565,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry @classmethod def parse_obj( - cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None + cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None ) -> "SQLModel": obj = cls._enforce_dict_if_root(obj) # SQLModel, support update dict diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index a5a63e2c..453e0eef 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -60,7 +60,7 @@ class Session(_Session): results = super().execute( statement, params=params, - execution_options=execution_options, # type: ignore + execution_options=execution_options, bind_arguments=bind_arguments, _parent_execute_state=_parent_execute_state, _add_event=_add_event, @@ -74,7 +74,7 @@ class Session(_Session): self, statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, @@ -101,7 +101,7 @@ class Session(_Session): return super().execute( # type: ignore statement, params=params, - execution_options=execution_options, # type: ignore + execution_options=execution_options, bind_arguments=bind_arguments, _parent_execute_state=_parent_execute_state, _add_event=_add_event, diff --git a/sqlmodel/sql/base.py b/sqlmodel/sql/base.py index 129e4d43..3764a972 100644 --- a/sqlmodel/sql/base.py +++ b/sqlmodel/sql/base.py @@ -6,6 +6,4 @@ _T = TypeVar("_T") class Executable(_Executable, Generic[_T]): - def __init__(self, *args, **kwargs): - self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None) - super(_Executable, self).__init__(*args, **kwargs) + pass diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 66063bf2..bf6ea38e 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -45,10 +45,10 @@ else: class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore pass - class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore + class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): pass - class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore + class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): pass # Cast them for editors to work correctly, from several tricks tried, this works @@ -65,9 +65,9 @@ if TYPE_CHECKING: # pragma: no cover _TScalar_0 = TypeVar( "_TScalar_0", - Column, - Sequence, - Mapping, + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore UUID, datetime, float, @@ -83,9 +83,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel") _TScalar_1 = TypeVar( "_TScalar_1", - Column, - Sequence, - Mapping, + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore UUID, datetime, float, @@ -101,9 +101,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel") _TScalar_2 = TypeVar( "_TScalar_2", - Column, - Sequence, - Mapping, + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore UUID, datetime, float, @@ -119,9 +119,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel") _TScalar_3 = TypeVar( "_TScalar_3", - Column, - Sequence, - Mapping, + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore UUID, datetime, float, @@ -446,14 +446,14 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: return SelectOfScalar._create(*entities, **kw) # type: ignore return Select._create(*entities, **kw) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents -def col(column_expression: Any) -> ColumnClause: +def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") return column_expression diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index b39d636e..9cd5d3f3 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -63,9 +63,9 @@ if TYPE_CHECKING: # pragma: no cover {% for i in range(number_of_types) %} _TScalar_{{ i }} = TypeVar( "_TScalar_{{ i }}", - Column, - Sequence, - Mapping, + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore UUID, datetime, float, @@ -106,14 +106,14 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) + return Select._create(*entities, **kw) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents -def col(column_expression: Any) -> ColumnClause: +def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") return column_expression diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index e7b77b8c..b3fda877 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,13 +1,14 @@ import uuid -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import types from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.types import CHAR, TypeDecorator -class AutoString(types.TypeDecorator): +class AutoString(types.TypeDecorator): # type: ignore impl = types.String cache_ok = True @@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator): # Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type # with small modifications -class GUID(TypeDecorator): +class GUID(TypeDecorator): # type: ignore """Platform-independent GUID type. Uses PostgreSQL's UUID type, otherwise uses @@ -33,13 +34,13 @@ class GUID(TypeDecorator): impl = CHAR cache_ok = True - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) + return dialect.type_descriptor(UUID()) # type: ignore else: - return dialect.type_descriptor(CHAR(32)) + return dialect.type_descriptor(CHAR(32)) # type: ignore - def process_bind_param(self, value, dialect): + def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: if value is None: return value elif dialect.name == "postgresql": @@ -51,10 +52,10 @@ class GUID(TypeDecorator): # hexstring return f"{value.int:x}" - def process_result_value(self, value, dialect): + def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]: if value is None: return value else: if not isinstance(value, uuid.UUID): value = uuid.UUID(value) - return value + return cast(uuid.UUID, value) -- 2.47.2