]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Update type annotations and upgrade mypy (#173)
authorSebastián Ramírez <tiangolo@gmail.com>
Tue, 30 Nov 2021 16:12:28 +0000 (17:12 +0100)
committerGitHub <noreply@github.com>
Tue, 30 Nov 2021 16:12:28 +0000 (17:12 +0100)
pyproject.toml
sqlmodel/engine/create.py
sqlmodel/engine/result.py
sqlmodel/ext/asyncio/session.py
sqlmodel/main.py
sqlmodel/orm/session.py
sqlmodel/sql/base.py
sqlmodel/sql/expression.py
sqlmodel/sql/expression.py.jinja2
sqlmodel/sql/sqltypes.py

index fc567909a8cced47f7b139d575439ee736210dc3..a8355cf1ad2e06d5b366d3aa8f7faf1a71b7a042 100644 (file)
@@ -37,7 +37,7 @@ sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
 
 [tool.poetry.dev-dependencies]
 pytest = "^6.2.4"
-mypy = "^0.812"
+mypy = "^0.910"
 flake8 = "^3.9.2"
 black = {version = "^21.5-beta.1", python = "^3.7"}
 mkdocs = "^1.2.1"
@@ -98,3 +98,7 @@ warn_return_any = true
 implicit_reexport = false
 strict_equality = true
 # --strict end
+
+[[tool.mypy.overrides]]
+module = "sqlmodel.sql.expression"
+warn_unused_ignores = false
index 97481259e28695f577a7bd1f739acbb9220715e0..b2d567b1b1a4c2b813b7d01962442b697388651c 100644 (file)
@@ -136,4 +136,4 @@ def create_engine(
     if not isinstance(query_cache_size, _DefaultPlaceholder):
         current_kwargs["query_cache_size"] = query_cache_size
     current_kwargs.update(kwargs)
-    return _create_engine(url, **current_kwargs)
+    return _create_engine(url, **current_kwargs)  # type: ignore
index d521427581bbc9545090905454d07a4508e68d54..7a25422227a42ac3b4d98fbf69824df16f9ffba4 100644 (file)
@@ -23,7 +23,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
         return super().__iter__()
 
     def __next__(self) -> _T:
-        return super().__next__()
+        return super().__next__()  # type: ignore
 
     def first(self) -> Optional[_T]:
         return super().first()
@@ -32,7 +32,7 @@ class ScalarResult(_ScalarResult, Generic[_T]):
         return super().one_or_none()
 
     def one(self) -> _T:
-        return super().one()
+        return super().one()  # type: ignore
 
 
 class Result(_Result, Generic[_T]):
@@ -70,10 +70,10 @@ class Result(_Result, Generic[_T]):
         return super().scalar_one()  # type: ignore
 
     def scalar_one_or_none(self) -> Optional[_T]:
-        return super().scalar_one_or_none()  # type: ignore
+        return super().scalar_one_or_none()
 
     def one(self) -> _T:  # type: ignore
         return super().one()  # type: ignore
 
     def scalar(self) -> Optional[_T]:
-        return super().scalar()  # type: ignore
+        return super().scalar()
index 40e5b766e9145fd4a124c0c84f2c8b00709f100c..80267b25e5243fa4e0709981f35671bc4ffc1a0a 100644 (file)
@@ -21,7 +21,7 @@ class AsyncSession(_AsyncSession):
         self,
         bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
         binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
-        **kw,
+        **kw: Any,
     ):
         # All the same code of the original AsyncSession
         kw["future"] = True
@@ -52,7 +52,7 @@ class AsyncSession(_AsyncSession):
         # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
         execution_options = execution_options.union({"prebuffer_rows": True})  # type: ignore
 
-        return await greenlet_spawn(  # type: ignore
+        return await greenlet_spawn(
             self.sync_session.exec,
             statement,
             params=params,
index 661276b31d30af565dbbefbfb0e26065423db682..84e26c4532f48e3aa74d6a1a6f7a720e6344aff1 100644 (file)
@@ -101,7 +101,7 @@ class RelationshipInfo(Representation):
         *,
         back_populates: Optional[str] = None,
         link_model: Optional[Any] = None,
-        sa_relationship: Optional[RelationshipProperty] = None,
+        sa_relationship: Optional[RelationshipProperty] = None,  # type: ignore
         sa_relationship_args: Optional[Sequence[Any]] = None,
         sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> None:
@@ -127,32 +127,32 @@ def Field(
     default: Any = Undefined,
     *,
     default_factory: Optional[NoArgAnyCallable] = None,
-    alias: str = None,
-    title: str = None,
-    description: str = None,
+    alias: Optional[str] = None,
+    title: Optional[str] = None,
+    description: Optional[str] = None,
     exclude: Union[
         AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
     ] = None,
     include: Union[
         AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
     ] = None,
-    const: bool = None,
-    gt: float = None,
-    ge: float = None,
-    lt: float = None,
-    le: float = None,
-    multiple_of: float = None,
-    min_items: int = None,
-    max_items: int = None,
-    min_length: int = None,
-    max_length: int = None,
+    const: Optional[bool] = None,
+    gt: Optional[float] = None,
+    ge: Optional[float] = None,
+    lt: Optional[float] = None,
+    le: Optional[float] = None,
+    multiple_of: Optional[float] = None,
+    min_items: Optional[int] = None,
+    max_items: Optional[int] = None,
+    min_length: Optional[int] = None,
+    max_length: Optional[int] = None,
     allow_mutation: bool = True,
-    regex: str = None,
+    regex: Optional[str] = None,
     primary_key: bool = False,
     foreign_key: Optional[Any] = None,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
-    sa_column: Union[Column, UndefinedType] = Undefined,
+    sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
     sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
     sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
     schema_extra: Optional[Dict[str, Any]] = None,
@@ -195,7 +195,7 @@ def Relationship(
     *,
     back_populates: Optional[str] = None,
     link_model: Optional[Any] = None,
-    sa_relationship: Optional[RelationshipProperty] = None,
+    sa_relationship: Optional[RelationshipProperty] = None,  # type: ignore
     sa_relationship_args: Optional[Sequence[Any]] = None,
     sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> Any:
@@ -217,19 +217,25 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
 
     # Replicate SQLAlchemy
     def __setattr__(cls, name: str, value: Any) -> None:
-        if getattr(cls.__config__, "table", False):  # type: ignore
+        if getattr(cls.__config__, "table", False):
             DeclarativeMeta.__setattr__(cls, name, value)
         else:
             super().__setattr__(name, value)
 
     def __delattr__(cls, name: str) -> None:
-        if getattr(cls.__config__, "table", False):  # type: ignore
+        if getattr(cls.__config__, "table", False):
             DeclarativeMeta.__delattr__(cls, name)
         else:
             super().__delattr__(name)
 
     # From Pydantic
-    def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
+    def __new__(
+        cls,
+        name: str,
+        bases: Tuple[Type[Any], ...],
+        class_dict: Dict[str, Any],
+        **kwargs: Any,
+    ) -> Any:
         relationships: Dict[str, RelationshipInfo] = {}
         dict_for_pydantic = {}
         original_annotations = resolve_annotations(
@@ -342,7 +348,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
                 )
                 relationship_to = temp_field.type_
                 if isinstance(temp_field.type_, ForwardRef):
-                    relationship_to = temp_field.type_.__forward_arg__  # type: ignore
+                    relationship_to = temp_field.type_.__forward_arg__
                 rel_kwargs: Dict[str, Any] = {}
                 if rel_info.back_populates:
                     rel_kwargs["back_populates"] = rel_info.back_populates
@@ -360,7 +366,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
                     rel_args.extend(rel_info.sa_relationship_args)
                 if rel_info.sa_relationship_kwargs:
                     rel_kwargs.update(rel_info.sa_relationship_kwargs)
-                rel_value: RelationshipProperty = relationship(
+                rel_value: RelationshipProperty = relationship(  # type: ignore
                     relationship_to, *rel_args, **rel_kwargs
                 )
                 dict_used[rel_name] = rel_value
@@ -408,7 +414,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
         return GUID
 
 
-def get_column_from_field(field: ModelField) -> Column:
+def get_column_from_field(field: ModelField) -> Column:  # type: ignore
     sa_column = getattr(field.field_info, "sa_column", Undefined)
     if isinstance(sa_column, Column):
         return sa_column
@@ -440,10 +446,10 @@ def get_column_from_field(field: ModelField) -> Column:
         kwargs["default"] = sa_default
     sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
     if sa_column_args is not Undefined:
-        args.extend(list(cast(Sequence, sa_column_args)))
+        args.extend(list(cast(Sequence[Any], sa_column_args)))
     sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
     if sa_column_kwargs is not Undefined:
-        kwargs.update(cast(dict, sa_column_kwargs))
+        kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
     return Column(sa_type, *args, **kwargs)
 
 
@@ -452,24 +458,27 @@ class_registry = weakref.WeakValueDictionary()  # type: ignore
 default_registry = registry()
 
 
-def _value_items_is_true(v) -> bool:
+def _value_items_is_true(v: Any) -> bool:
     # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of
     # the current latest, Pydantic 1.8.2
     return v is True or v is ...
 
 
+_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
+
+
 class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
     # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
     __slots__ = ("__weakref__",)
     __tablename__: ClassVar[Union[str, Callable[..., str]]]
-    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
+    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]  # type: ignore
     __name__: ClassVar[str]
     metadata: ClassVar[MetaData]
 
     class Config:
         orm_mode = True
 
-    def __new__(cls, *args, **kwargs) -> Any:
+    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
         new_object = super().__new__(cls)
         # SQLAlchemy doesn't call __init__ on the base class
         # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
@@ -520,7 +529,9 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
                 super().__setattr__(name, value)
 
     @classmethod
-    def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
+    def from_orm(
+        cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None
+    ) -> _TSQLModel:
         # Duplicated from Pydantic
         if not cls.__config__.orm_mode:
             raise ConfigError(
@@ -533,7 +544,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
         # End SQLModel support dict
         if not getattr(cls.__config__, "table", False):
             # If not table, normal Pydantic code
-            m = cls.__new__(cls)
+            m: _TSQLModel = cls.__new__(cls)
         else:
             # If table, create the new instance normally to make SQLAlchemy create
             # the _sa_instance_state attribute
@@ -554,7 +565,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
 
     @classmethod
     def parse_obj(
-        cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
+        cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
     ) -> "SQLModel":
         obj = cls._enforce_dict_if_root(obj)
         # SQLModel, support update dict
index a5a63e2c69b2e342967ea1f2dfa8627b0e6d48db..453e0eefafab3a1cc1e4931e381b2c9d74cd371a 100644 (file)
@@ -60,7 +60,7 @@ class Session(_Session):
         results = super().execute(
             statement,
             params=params,
-            execution_options=execution_options,  # type: ignore
+            execution_options=execution_options,
             bind_arguments=bind_arguments,
             _parent_execute_state=_parent_execute_state,
             _add_event=_add_event,
@@ -74,7 +74,7 @@ class Session(_Session):
         self,
         statement: _Executable,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
-        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
         bind_arguments: Optional[Mapping[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
@@ -101,7 +101,7 @@ class Session(_Session):
         return super().execute(  # type: ignore
             statement,
             params=params,
-            execution_options=execution_options,  # type: ignore
+            execution_options=execution_options,
             bind_arguments=bind_arguments,
             _parent_execute_state=_parent_execute_state,
             _add_event=_add_event,
index 129e4d43d77b0e0425f1268d0937f43ba64c0a04..3764a9721d81a122868f8be00ab3e788cfe94cf8 100644 (file)
@@ -6,6 +6,4 @@ _T = TypeVar("_T")
 
 
 class Executable(_Executable, Generic[_T]):
-    def __init__(self, *args, **kwargs):
-        self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
-        super(_Executable, self).__init__(*args, **kwargs)
+    pass
index 66063bf2364a4cd655557e7a8fe46db0847698c4..bf6ea38ec68fd42b0836d6a740a51dc6b54518e6 100644 (file)
@@ -45,10 +45,10 @@ else:
     class GenericSelectMeta(GenericMeta, _Select.__class__):  # type: ignore
         pass
 
-    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):  # type: ignore
+    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
         pass
 
-    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):  # type: ignore
+    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
         pass
 
     # Cast them for editors to work correctly, from several tricks tried, this works
@@ -65,9 +65,9 @@ if TYPE_CHECKING:  # pragma: no cover
 
 _TScalar_0 = TypeVar(
     "_TScalar_0",
-    Column,
-    Sequence,
-    Mapping,
+    Column,  # type: ignore
+    Sequence,  # type: ignore
+    Mapping,  # type: ignore
     UUID,
     datetime,
     float,
@@ -83,9 +83,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
 
 _TScalar_1 = TypeVar(
     "_TScalar_1",
-    Column,
-    Sequence,
-    Mapping,
+    Column,  # type: ignore
+    Sequence,  # type: ignore
+    Mapping,  # type: ignore
     UUID,
     datetime,
     float,
@@ -101,9 +101,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
 
 _TScalar_2 = TypeVar(
     "_TScalar_2",
-    Column,
-    Sequence,
-    Mapping,
+    Column,  # type: ignore
+    Sequence,  # type: ignore
+    Mapping,  # type: ignore
     UUID,
     datetime,
     float,
@@ -119,9 +119,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
 
 _TScalar_3 = TypeVar(
     "_TScalar_3",
-    Column,
-    Sequence,
-    Mapping,
+    Column,  # type: ignore
+    Sequence,  # type: ignore
+    Mapping,  # type: ignore
     UUID,
     datetime,
     float,
@@ -446,14 +446,14 @@ def select(  # type: ignore
 # Generated overloads end
 
 
-def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
     if len(entities) == 1:
         return SelectOfScalar._create(*entities, **kw)  # type: ignore
     return Select._create(*entities, **kw)  # type: ignore
 
 
 # TODO: add several @overload from Python types to SQLAlchemy equivalents
-def col(column_expression: Any) -> ColumnClause:
+def col(column_expression: Any) -> ColumnClause:  # type: ignore
     if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
         raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
     return column_expression
index b39d636ea22fda722f8857b7d3adfbbdb7761758..9cd5d3f33e4a13aae88d0af8b6aedfb713c8bb1b 100644 (file)
@@ -63,9 +63,9 @@ if TYPE_CHECKING:  # pragma: no cover
 {% for i in range(number_of_types) %}
 _TScalar_{{ i }} = TypeVar(
     "_TScalar_{{ i }}",
-    Column,
-    Sequence,
-    Mapping,
+    Column,  # type: ignore
+    Sequence,  # type: ignore
+    Mapping,  # type: ignore
     UUID,
     datetime,
     float,
@@ -106,14 +106,14 @@ def select(  # type: ignore
 
 # Generated overloads end
 
-def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
     if len(entities) == 1:
         return SelectOfScalar._create(*entities, **kw)  # type: ignore
-    return Select._create(*entities, **kw)
+    return Select._create(*entities, **kw)  # type: ignore
 
 
 # TODO: add several @overload from Python types to SQLAlchemy equivalents
-def col(column_expression: Any) -> ColumnClause:
+def col(column_expression: Any) -> ColumnClause:  # type: ignore
     if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
         raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
     return column_expression
index e7b77b8c524c92f63d81f4da19f7c20f0ef1e45d..b3fda87739f79aa683cefca1950b473f0422139f 100644 (file)
@@ -1,13 +1,14 @@
 import uuid
-from typing import Any, cast
+from typing import Any, Optional, cast
 
 from sqlalchemy import types
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.sql.type_api import TypeEngine
 from sqlalchemy.types import CHAR, TypeDecorator
 
 
-class AutoString(types.TypeDecorator):
+class AutoString(types.TypeDecorator):  # type: ignore
 
     impl = types.String
     cache_ok = True
@@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
 
 # Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
 # with small modifications
-class GUID(TypeDecorator):
+class GUID(TypeDecorator):  # type: ignore
     """Platform-independent GUID type.
 
     Uses PostgreSQL's UUID type, otherwise uses
@@ -33,13 +34,13 @@ class GUID(TypeDecorator):
     impl = CHAR
     cache_ok = True
 
-    def load_dialect_impl(self, dialect):
+    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine:  # type: ignore
         if dialect.name == "postgresql":
-            return dialect.type_descriptor(UUID())
+            return dialect.type_descriptor(UUID())  # type: ignore
         else:
-            return dialect.type_descriptor(CHAR(32))
+            return dialect.type_descriptor(CHAR(32))  # type: ignore
 
-    def process_bind_param(self, value, dialect):
+    def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
         if value is None:
             return value
         elif dialect.name == "postgresql":
@@ -51,10 +52,10 @@ class GUID(TypeDecorator):
                 # hexstring
                 return f"{value.int:x}"
 
-    def process_result_value(self, value, dialect):
+    def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
         if value is None:
             return value
         else:
             if not isinstance(value, uuid.UUID):
                 value = uuid.UUID(value)
-            return value
+            return cast(uuid.UUID, value)