]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Upgrade SQLAlchemy to 2.0, including initial work by farahats9 (#700)
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 18 Nov 2023 11:30:37 +0000 (12:30 +0100)
committerGitHub <noreply@github.com>
Sat, 18 Nov 2023 11:30:37 +0000 (12:30 +0100)
Co-authored-by: Mohamed Farahat <farahats9@yahoo.com>
Co-authored-by: Stefan Borer <stefan.borer@gmail.com>
Co-authored-by: Peter Landry <peter.landry@gmail.com>
24 files changed:
.github/workflows/test.yml
pyproject.toml
scripts/generate_select.py
sqlmodel/__init__.py
sqlmodel/engine/__init__.py [deleted file]
sqlmodel/engine/create.py [deleted file]
sqlmodel/engine/result.py [deleted file]
sqlmodel/ext/asyncio/session.py
sqlmodel/main.py
sqlmodel/orm/session.py
sqlmodel/sql/expression.py
sqlmodel/sql/expression.py.jinja2
sqlmodel/sql/sqltypes.py
tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py
tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py
tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py

index 201abc7c22f8dc3c1d1bb786e5ee742f7f3c6185..c3b07f484eff0f71b961e23f3995c5dfca8784db 100644 (file)
@@ -56,6 +56,8 @@ jobs:
         if: steps.cache.outputs.cache-hit != 'true'
         run: python -m poetry install
       - name: Lint
+        # Do not run on Python 3.7 as mypy behaves differently
+        if: matrix.python-version != '3.7'
         run: python -m poetry run bash scripts/lint.sh
       - run: mkdir coverage
       - name: Test
index 23fa79bf31ddb63799939315cf7a0b1d334da3fa..515bbaf66ce99bab7ab0c251fad9aa5e902cc003 100644 (file)
@@ -31,9 +31,8 @@ classifiers = [
 
 [tool.poetry.dependencies]
 python = "^3.7"
-SQLAlchemy = ">=1.4.36,<2.0.0"
+SQLAlchemy = ">=2.0.0,<2.1.0"
 pydantic = "^1.9.0"
-sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.0.1"
@@ -45,9 +44,10 @@ pillow = "^9.3.0"
 cairosvg = "^2.5.2"
 mdx-include = "^1.4.1"
 coverage = {extras = ["toml"], version = ">=6.2,<8.0"}
-fastapi = "^0.68.1"
-requests = "^2.26.0"
+fastapi = "^0.103.2"
 ruff = "^0.1.2"
+# For FastAPI tests
+httpx = "0.24.1"
 
 [build-system]
 requires = ["poetry-core"]
@@ -80,6 +80,12 @@ strict = true
 module = "sqlmodel.sql.expression"
 warn_unused_ignores = false
 
+[[tool.mypy.overrides]]
+module = "docs_src.*"
+disallow_incomplete_defs = false
+disallow_untyped_defs = false
+disallow_untyped_calls = false
+
 [tool.ruff]
 select = [
     "E",  # pycodestyle errors
index f8aa30023fad1d3d251b8e7b920fcd27aa0aa16c..88e0e0a997cbe32d84600a59e5219826e77708a7 100644 (file)
@@ -34,9 +34,9 @@ for total_args in range(2, number_of_types + 1):
                 arg = Arg(name=f"entity_{i}", annotation=t_var)
                 ret_type = t_var
             else:
-                t_type = f"_TModel_{i}"
-                t_var = f"Type[{t_type}]"
-                arg = Arg(name=f"entity_{i}", annotation=t_var)
+                t_type = f"_T{i}"
+                t_var = f"_TCCA[{t_type}]"
+                arg = Arg(name=f"__ent{i}", annotation=t_var)
                 ret_type = t_type
             args.append(arg)
             return_types.append(ret_type)
index 495ac9c8a810cb223b51693b04b0a7ebab99f8e0..e9432571651addf66d00c12312a9a21e1c5f5401 100644 (file)
@@ -1,9 +1,12 @@
 __version__ = "0.0.11"
 
 # Re-export from SQLAlchemy
+from sqlalchemy.engine import create_engine as create_engine
 from sqlalchemy.engine import create_mock_engine as create_mock_engine
 from sqlalchemy.engine import engine_from_config as engine_from_config
 from sqlalchemy.inspection import inspect as inspect
+from sqlalchemy.pool import QueuePool as QueuePool
+from sqlalchemy.pool import StaticPool as StaticPool
 from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
 from sqlalchemy.schema import DDL as DDL
 from sqlalchemy.schema import CheckConstraint as CheckConstraint
@@ -21,7 +24,6 @@ from sqlalchemy.schema import MetaData as MetaData
 from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
 from sqlalchemy.schema import Sequence as Sequence
 from sqlalchemy.schema import Table as Table
-from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
 from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
 from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
 from sqlalchemy.sql import (
@@ -32,26 +34,14 @@ from sqlalchemy.sql import (
     LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
 )
 from sqlalchemy.sql import alias as alias
-from sqlalchemy.sql import all_ as all_
-from sqlalchemy.sql import and_ as and_
-from sqlalchemy.sql import any_ as any_
-from sqlalchemy.sql import asc as asc
-from sqlalchemy.sql import between as between
 from sqlalchemy.sql import bindparam as bindparam
-from sqlalchemy.sql import case as case
-from sqlalchemy.sql import cast as cast
-from sqlalchemy.sql import collate as collate
 from sqlalchemy.sql import column as column
 from sqlalchemy.sql import delete as delete
-from sqlalchemy.sql import desc as desc
-from sqlalchemy.sql import distinct as distinct
 from sqlalchemy.sql import except_ as except_
 from sqlalchemy.sql import except_all as except_all
 from sqlalchemy.sql import exists as exists
-from sqlalchemy.sql import extract as extract
 from sqlalchemy.sql import false as false
 from sqlalchemy.sql import func as func
-from sqlalchemy.sql import funcfilter as funcfilter
 from sqlalchemy.sql import insert as insert
 from sqlalchemy.sql import intersect as intersect
 from sqlalchemy.sql import intersect_all as intersect_all
@@ -61,28 +51,19 @@ from sqlalchemy.sql import lateral as lateral
 from sqlalchemy.sql import literal as literal
 from sqlalchemy.sql import literal_column as literal_column
 from sqlalchemy.sql import modifier as modifier
-from sqlalchemy.sql import not_ as not_
 from sqlalchemy.sql import null as null
-from sqlalchemy.sql import nulls_first as nulls_first
-from sqlalchemy.sql import nulls_last as nulls_last
 from sqlalchemy.sql import nullsfirst as nullsfirst
 from sqlalchemy.sql import nullslast as nullslast
-from sqlalchemy.sql import or_ as or_
 from sqlalchemy.sql import outerjoin as outerjoin
 from sqlalchemy.sql import outparam as outparam
-from sqlalchemy.sql import over as over
-from sqlalchemy.sql import subquery as subquery
 from sqlalchemy.sql import table as table
 from sqlalchemy.sql import tablesample as tablesample
 from sqlalchemy.sql import text as text
 from sqlalchemy.sql import true as true
-from sqlalchemy.sql import tuple_ as tuple_
-from sqlalchemy.sql import type_coerce as type_coerce
 from sqlalchemy.sql import union as union
 from sqlalchemy.sql import union_all as union_all
 from sqlalchemy.sql import update as update
 from sqlalchemy.sql import values as values
-from sqlalchemy.sql import within_group as within_group
 from sqlalchemy.types import ARRAY as ARRAY
 from sqlalchemy.types import BIGINT as BIGINT
 from sqlalchemy.types import BINARY as BINARY
@@ -93,6 +74,8 @@ from sqlalchemy.types import CLOB as CLOB
 from sqlalchemy.types import DATE as DATE
 from sqlalchemy.types import DATETIME as DATETIME
 from sqlalchemy.types import DECIMAL as DECIMAL
+from sqlalchemy.types import DOUBLE as DOUBLE
+from sqlalchemy.types import DOUBLE_PRECISION as DOUBLE_PRECISION
 from sqlalchemy.types import FLOAT as FLOAT
 from sqlalchemy.types import INT as INT
 from sqlalchemy.types import INTEGER as INTEGER
@@ -105,12 +88,14 @@ from sqlalchemy.types import SMALLINT as SMALLINT
 from sqlalchemy.types import TEXT as TEXT
 from sqlalchemy.types import TIME as TIME
 from sqlalchemy.types import TIMESTAMP as TIMESTAMP
+from sqlalchemy.types import UUID as UUID
 from sqlalchemy.types import VARBINARY as VARBINARY
 from sqlalchemy.types import VARCHAR as VARCHAR
 from sqlalchemy.types import BigInteger as BigInteger
 from sqlalchemy.types import Boolean as Boolean
 from sqlalchemy.types import Date as Date
 from sqlalchemy.types import DateTime as DateTime
+from sqlalchemy.types import Double as Double
 from sqlalchemy.types import Enum as Enum
 from sqlalchemy.types import Float as Float
 from sqlalchemy.types import Integer as Integer
@@ -122,16 +107,38 @@ from sqlalchemy.types import SmallInteger as SmallInteger
 from sqlalchemy.types import String as String
 from sqlalchemy.types import Text as Text
 from sqlalchemy.types import Time as Time
+from sqlalchemy.types import TupleType as TupleType
 from sqlalchemy.types import TypeDecorator as TypeDecorator
 from sqlalchemy.types import Unicode as Unicode
 from sqlalchemy.types import UnicodeText as UnicodeText
+from sqlalchemy.types import Uuid as Uuid
 
 # From SQLModel, modifications of SQLAlchemy or equivalents of Pydantic
-from .engine.create import create_engine as create_engine
 from .main import Field as Field
 from .main import Relationship as Relationship
 from .main import SQLModel as SQLModel
 from .orm.session import Session as Session
+from .sql.expression import all_ as all_
+from .sql.expression import and_ as and_
+from .sql.expression import any_ as any_
+from .sql.expression import asc as asc
+from .sql.expression import between as between
+from .sql.expression import case as case
+from .sql.expression import cast as cast
 from .sql.expression import col as col
+from .sql.expression import collate as collate
+from .sql.expression import desc as desc
+from .sql.expression import distinct as distinct
+from .sql.expression import extract as extract
+from .sql.expression import funcfilter as funcfilter
+from .sql.expression import not_ as not_
+from .sql.expression import nulls_first as nulls_first
+from .sql.expression import nulls_last as nulls_last
+from .sql.expression import or_ as or_
+from .sql.expression import over as over
 from .sql.expression import select as select
+from .sql.expression import tuple_ as tuple_
+from .sql.expression import type_coerce as type_coerce
+from .sql.expression import within_group as within_group
+from .sql.sqltypes import GUID as GUID
 from .sql.sqltypes import AutoString as AutoString
diff --git a/sqlmodel/engine/__init__.py b/sqlmodel/engine/__init__.py
deleted file mode 100644 (file)
index e69de29..0000000
diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py
deleted file mode 100644 (file)
index b2d567b..0000000
+++ /dev/null
@@ -1,139 +0,0 @@
-import json
-import sqlite3
-from typing import Any, Callable, Dict, List, Optional, Type, Union
-
-from sqlalchemy import create_engine as _create_engine
-from sqlalchemy.engine.url import URL
-from sqlalchemy.future import Engine as _FutureEngine
-from sqlalchemy.pool import Pool
-from typing_extensions import Literal, TypedDict
-
-from ..default import Default, _DefaultPlaceholder
-
-# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here
-
-_Debug = Literal["debug"]
-
-_IsolationLevel = Literal[
-    "SERIALIZABLE",
-    "REPEATABLE READ",
-    "READ COMMITTED",
-    "READ UNCOMMITTED",
-    "AUTOCOMMIT",
-]
-_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
-_ResetOnReturn = Literal["rollback", "commit"]
-
-
-class _SQLiteConnectArgs(TypedDict, total=False):
-    timeout: float
-    detect_types: Any
-    isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
-    check_same_thread: bool
-    factory: Type[sqlite3.Connection]
-    cached_statements: int
-    uri: bool
-
-
-_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
-
-
-# Re-define create_engine to have by default future=True, and assume that's what is used
-# Also show the default values used for each parameter, but don't set them unless
-# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't
-# support pool connection arguments.
-def create_engine(
-    url: Union[str, URL],
-    *,
-    connect_args: _ConnectArgs = Default({}),  # type: ignore
-    echo: Union[bool, _Debug] = Default(False),
-    echo_pool: Union[bool, _Debug] = Default(False),
-    enable_from_linting: bool = Default(True),
-    encoding: str = Default("utf-8"),
-    execution_options: Dict[Any, Any] = Default({}),
-    future: bool = True,
-    hide_parameters: bool = Default(False),
-    implicit_returning: bool = Default(True),
-    isolation_level: Optional[_IsolationLevel] = Default(None),
-    json_deserializer: Callable[..., Any] = Default(json.loads),
-    json_serializer: Callable[..., Any] = Default(json.dumps),
-    label_length: Optional[int] = Default(None),
-    logging_name: Optional[str] = Default(None),
-    max_identifier_length: Optional[int] = Default(None),
-    max_overflow: int = Default(10),
-    module: Optional[Any] = Default(None),
-    paramstyle: Optional[_ParamStyle] = Default(None),
-    pool: Optional[Pool] = Default(None),
-    poolclass: Optional[Type[Pool]] = Default(None),
-    pool_logging_name: Optional[str] = Default(None),
-    pool_pre_ping: bool = Default(False),
-    pool_size: int = Default(5),
-    pool_recycle: int = Default(-1),
-    pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
-    pool_timeout: float = Default(30),
-    pool_use_lifo: bool = Default(False),
-    plugins: Optional[List[str]] = Default(None),
-    query_cache_size: Optional[int] = Default(None),
-    **kwargs: Any,
-) -> _FutureEngine:
-    current_kwargs: Dict[str, Any] = {
-        "future": future,
-    }
-    if not isinstance(echo, _DefaultPlaceholder):
-        current_kwargs["echo"] = echo
-    if not isinstance(echo_pool, _DefaultPlaceholder):
-        current_kwargs["echo_pool"] = echo_pool
-    if not isinstance(enable_from_linting, _DefaultPlaceholder):
-        current_kwargs["enable_from_linting"] = enable_from_linting
-    if not isinstance(connect_args, _DefaultPlaceholder):
-        current_kwargs["connect_args"] = connect_args
-    if not isinstance(encoding, _DefaultPlaceholder):
-        current_kwargs["encoding"] = encoding
-    if not isinstance(execution_options, _DefaultPlaceholder):
-        current_kwargs["execution_options"] = execution_options
-    if not isinstance(hide_parameters, _DefaultPlaceholder):
-        current_kwargs["hide_parameters"] = hide_parameters
-    if not isinstance(implicit_returning, _DefaultPlaceholder):
-        current_kwargs["implicit_returning"] = implicit_returning
-    if not isinstance(isolation_level, _DefaultPlaceholder):
-        current_kwargs["isolation_level"] = isolation_level
-    if not isinstance(json_deserializer, _DefaultPlaceholder):
-        current_kwargs["json_deserializer"] = json_deserializer
-    if not isinstance(json_serializer, _DefaultPlaceholder):
-        current_kwargs["json_serializer"] = json_serializer
-    if not isinstance(label_length, _DefaultPlaceholder):
-        current_kwargs["label_length"] = label_length
-    if not isinstance(logging_name, _DefaultPlaceholder):
-        current_kwargs["logging_name"] = logging_name
-    if not isinstance(max_identifier_length, _DefaultPlaceholder):
-        current_kwargs["max_identifier_length"] = max_identifier_length
-    if not isinstance(max_overflow, _DefaultPlaceholder):
-        current_kwargs["max_overflow"] = max_overflow
-    if not isinstance(module, _DefaultPlaceholder):
-        current_kwargs["module"] = module
-    if not isinstance(paramstyle, _DefaultPlaceholder):
-        current_kwargs["paramstyle"] = paramstyle
-    if not isinstance(pool, _DefaultPlaceholder):
-        current_kwargs["pool"] = pool
-    if not isinstance(poolclass, _DefaultPlaceholder):
-        current_kwargs["poolclass"] = poolclass
-    if not isinstance(pool_logging_name, _DefaultPlaceholder):
-        current_kwargs["pool_logging_name"] = pool_logging_name
-    if not isinstance(pool_pre_ping, _DefaultPlaceholder):
-        current_kwargs["pool_pre_ping"] = pool_pre_ping
-    if not isinstance(pool_size, _DefaultPlaceholder):
-        current_kwargs["pool_size"] = pool_size
-    if not isinstance(pool_recycle, _DefaultPlaceholder):
-        current_kwargs["pool_recycle"] = pool_recycle
-    if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
-        current_kwargs["pool_reset_on_return"] = pool_reset_on_return
-    if not isinstance(pool_timeout, _DefaultPlaceholder):
-        current_kwargs["pool_timeout"] = pool_timeout
-    if not isinstance(pool_use_lifo, _DefaultPlaceholder):
-        current_kwargs["pool_use_lifo"] = pool_use_lifo
-    if not isinstance(plugins, _DefaultPlaceholder):
-        current_kwargs["plugins"] = plugins
-    if not isinstance(query_cache_size, _DefaultPlaceholder):
-        current_kwargs["query_cache_size"] = query_cache_size
-    current_kwargs.update(kwargs)
-    return _create_engine(url, **current_kwargs)  # type: ignore
diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py
deleted file mode 100644 (file)
index 7a25422..0000000
+++ /dev/null
@@ -1,79 +0,0 @@
-from typing import Generic, Iterator, List, Optional, TypeVar
-
-from sqlalchemy.engine.result import Result as _Result
-from sqlalchemy.engine.result import ScalarResult as _ScalarResult
-
-_T = TypeVar("_T")
-
-
-class ScalarResult(_ScalarResult, Generic[_T]):
-    def all(self) -> List[_T]:
-        return super().all()
-
-    def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
-        return super().partitions(size)
-
-    def fetchall(self) -> List[_T]:
-        return super().fetchall()
-
-    def fetchmany(self, size: Optional[int] = None) -> List[_T]:
-        return super().fetchmany(size)
-
-    def __iter__(self) -> Iterator[_T]:
-        return super().__iter__()
-
-    def __next__(self) -> _T:
-        return super().__next__()  # type: ignore
-
-    def first(self) -> Optional[_T]:
-        return super().first()
-
-    def one_or_none(self) -> Optional[_T]:
-        return super().one_or_none()
-
-    def one(self) -> _T:
-        return super().one()  # type: ignore
-
-
-class Result(_Result, Generic[_T]):
-    def scalars(self, index: int = 0) -> ScalarResult[_T]:
-        return super().scalars(index)  # type: ignore
-
-    def __iter__(self) -> Iterator[_T]:  # type: ignore
-        return super().__iter__()  # type: ignore
-
-    def __next__(self) -> _T:  # type: ignore
-        return super().__next__()  # type: ignore
-
-    def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:  # type: ignore
-        return super().partitions(size)  # type: ignore
-
-    def fetchall(self) -> List[_T]:  # type: ignore
-        return super().fetchall()  # type: ignore
-
-    def fetchone(self) -> Optional[_T]:  # type: ignore
-        return super().fetchone()  # type: ignore
-
-    def fetchmany(self, size: Optional[int] = None) -> List[_T]:  # type: ignore
-        return super().fetchmany()  # type: ignore
-
-    def all(self) -> List[_T]:  # type: ignore
-        return super().all()  # type: ignore
-
-    def first(self) -> Optional[_T]:  # type: ignore
-        return super().first()  # type: ignore
-
-    def one_or_none(self) -> Optional[_T]:  # type: ignore
-        return super().one_or_none()  # type: ignore
-
-    def scalar_one(self) -> _T:
-        return super().scalar_one()  # type: ignore
-
-    def scalar_one_or_none(self) -> Optional[_T]:
-        return super().scalar_one_or_none()
-
-    def one(self) -> _T:  # type: ignore
-        return super().one()  # type: ignore
-
-    def scalar(self) -> Optional[_T]:
-        return super().scalar()
index f500c44dc2192701fb6a8b37f5ed9f195d4b1989..012d8ef5e494293ba8d7938309efa30f1e14c1f7 100644 (file)
@@ -1,45 +1,38 @@
-from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
+from typing import (
+    Any,
+    Dict,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 
 from sqlalchemy import util
+from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
+from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
 from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
-from sqlalchemy.ext.asyncio import engine
-from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
+from sqlalchemy.ext.asyncio.result import _ensure_sync_result
+from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
+from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
+from sqlalchemy.sql.base import Executable as _Executable
 from sqlalchemy.util.concurrency import greenlet_spawn
+from typing_extensions import deprecated
 
-from ...engine.result import Result, ScalarResult
 from ...orm.session import Session
 from ...sql.base import Executable
 from ...sql.expression import Select, SelectOfScalar
 
-_TSelectParam = TypeVar("_TSelectParam")
+_TSelectParam = TypeVar("_TSelectParam", bound=Any)
 
 
 class AsyncSession(_AsyncSession):
+    sync_session_class: Type[Session] = Session
     sync_session: Session
 
-    def __init__(
-        self,
-        bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
-        binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
-        **kw: Any,
-    ):
-        # All the same code of the original AsyncSession
-        kw["future"] = True
-        if bind:
-            self.bind = bind
-            bind = engine._get_sync_engine_or_connection(bind)  # type: ignore
-
-        if binds:
-            self.binds = binds
-            binds = {
-                key: engine._get_sync_engine_or_connection(b)  # type: ignore
-                for key, b in binds.items()
-            }
-
-        self.sync_session = self._proxied = self._assign_proxied(  # type: ignore
-            Session(bind=bind, binds=binds, **kw)  # type: ignore
-        )
-
     @overload
     async def exec(
         self,
@@ -47,11 +40,10 @@ class AsyncSession(_AsyncSession):
         *,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
         execution_options: Mapping[str, Any] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
+        bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-        **kw: Any,
-    ) -> Result[_TSelectParam]:
+    ) -> TupleResult[_TSelectParam]:
         ...
 
     @overload
@@ -61,10 +53,9 @@ class AsyncSession(_AsyncSession):
         *,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
         execution_options: Mapping[str, Any] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
+        bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-        **kw: Any,
     ) -> ScalarResult[_TSelectParam]:
         ...
 
@@ -75,20 +66,87 @@ class AsyncSession(_AsyncSession):
             SelectOfScalar[_TSelectParam],
             Executable[_TSelectParam],
         ],
+        *,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
-        execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
-        **kw: Any,
-    ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
-        # TODO: the documentation says execution_options accepts a dict, but only
-        # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
-        execution_options = execution_options.union({"prebuffer_rows": True})  # type: ignore
-
-        return await greenlet_spawn(
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]:
+        if execution_options:
+            execution_options = util.immutabledict(execution_options).union(
+                _EXECUTE_OPTIONS
+            )
+        else:
+            execution_options = _EXECUTE_OPTIONS
+
+        result = await greenlet_spawn(
             self.sync_session.exec,
             statement,
             params=params,
             execution_options=execution_options,
             bind_arguments=bind_arguments,
-            **kw,
+            _parent_execute_state=_parent_execute_state,
+            _add_event=_add_event,
+        )
+        result_value = await _ensure_sync_result(
+            cast(Result[_TSelectParam], result), self.exec
+        )
+        return result_value  # type: ignore
+
+    @deprecated(
+        """
+        🚨 You probably want to use `session.exec()` instead of `session.execute()`.
+
+        This is the original SQLAlchemy `session.execute()` method that returns objects
+        of type `Row`, and that you have to call `scalars()` to get the model objects.
+
+        For example:
+
+        ```Python
+        heroes = await session.execute(select(Hero)).scalars().all()
+        ```
+
+        instead you could use `exec()`:
+
+        ```Python
+        heroes = await session.exec(select(Hero)).all()
+        ```
+        """
+    )
+    async def execute(  # type: ignore
+        self,
+        statement: _Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+    ) -> Result[Any]:
+        """
+        🚨 You probably want to use `session.exec()` instead of `session.execute()`.
+
+        This is the original SQLAlchemy `session.execute()` method that returns objects
+        of type `Row`, and that you have to call `scalars()` to get the model objects.
+
+        For example:
+
+        ```Python
+        heroes = await session.execute(select(Hero)).scalars().all()
+        ```
+
+        instead you could use `exec()`:
+
+        ```Python
+        heroes = await session.exec(select(Hero)).all()
+        ```
+        """
+        return await super().execute(
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            _parent_execute_state=_parent_execute_state,
+            _add_event=_add_event,
         )
index 2b69dd2a75929e3dd24bc0c866c654c837e257e1..c30af5779f3f07b5789e476e629963ab87a93c2b 100644 (file)
@@ -45,12 +45,19 @@ from sqlalchemy import (
     inspect,
 )
 from sqlalchemy import Enum as sa_Enum
-from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
+from sqlalchemy.orm import (
+    Mapped,
+    RelationshipProperty,
+    declared_attr,
+    registry,
+    relationship,
+)
 from sqlalchemy.orm.attributes import set_attribute
 from sqlalchemy.orm.decl_api import DeclarativeMeta
 from sqlalchemy.orm.instrumentation import is_instrumented
 from sqlalchemy.sql.schema import MetaData
 from sqlalchemy.sql.sqltypes import LargeBinary, Time
+from typing_extensions import get_origin
 
 from .sql.sqltypes import GUID, AutoString
 
@@ -483,7 +490,16 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
                     # over anything else, use that and continue with the next attribute
                     setattr(cls, rel_name, rel_info.sa_relationship)  # Fix #315
                     continue
-                ann = cls.__annotations__[rel_name]
+                raw_ann = cls.__annotations__[rel_name]
+                origin = get_origin(raw_ann)
+                if origin is Mapped:
+                    ann = raw_ann.__args__[0]
+                else:
+                    ann = raw_ann
+                    # Plain forward references, for models not yet defined, are not
+                    # handled well by SQLAlchemy without Mapped, so, wrap the
+                    # annotations in Mapped here
+                    cls.__annotations__[rel_name] = Mapped[ann]  # type: ignore[valid-type]
                 temp_field = ModelField.infer(
                     name=rel_name,
                     value=rel_info,
@@ -511,9 +527,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
                     rel_args.extend(rel_info.sa_relationship_args)
                 if rel_info.sa_relationship_kwargs:
                     rel_kwargs.update(rel_info.sa_relationship_kwargs)
-                rel_value: RelationshipProperty = relationship(  # type: ignore
-                    relationship_to, *rel_args, **rel_kwargs
-                )
+                rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
                 setattr(cls, rel_name, rel_value)  # Fix #315
             # SQLAlchemy no longer uses dict_
             # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
@@ -642,6 +656,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
     __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]  # type: ignore
     __name__: ClassVar[str]
     metadata: ClassVar[MetaData]
+    __allow_unmapped__ = True  # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
 
     class Config:
         orm_mode = True
@@ -685,7 +700,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
             return
         else:
             # Set in SQLAlchemy, before Pydantic to trigger events and updates
-            if getattr(self.__config__, "table", False) and is_instrumented(self, name):
+            if getattr(self.__config__, "table", False) and is_instrumented(self, name):  # type: ignore
                 set_attribute(self, name, value)
             # Set in Pydantic model to trigger possible validation changes, only for
             # non relationship values
index 0c70c290ae9e5b3317bd2ddecf2665bcb83531a0..6050d5fbc111a0a71b9953be04a8b00294d491ab 100644 (file)
@@ -1,16 +1,27 @@
-from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
+from typing import (
+    Any,
+    Dict,
+    Mapping,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+    overload,
+)
 
 from sqlalchemy import util
+from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
+from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
 from sqlalchemy.orm import Query as _Query
 from sqlalchemy.orm import Session as _Session
+from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
+from sqlalchemy.sql._typing import _ColumnsClauseArgument
 from sqlalchemy.sql.base import Executable as _Executable
-from typing_extensions import Literal
+from sqlmodel.sql.base import Executable
+from sqlmodel.sql.expression import Select, SelectOfScalar
+from typing_extensions import deprecated
 
-from ..engine.result import Result, ScalarResult
-from ..sql.base import Executable
-from ..sql.expression import Select, SelectOfScalar
-
-_TSelectParam = TypeVar("_TSelectParam")
+_TSelectParam = TypeVar("_TSelectParam", bound=Any)
 
 
 class Session(_Session):
@@ -21,11 +32,10 @@ class Session(_Session):
         *,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
         execution_options: Mapping[str, Any] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
+        bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-        **kw: Any,
-    ) -> Result[_TSelectParam]:
+    ) -> TupleResult[_TSelectParam]:
         ...
 
     @overload
@@ -35,10 +45,9 @@ class Session(_Session):
         *,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
         execution_options: Mapping[str, Any] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
+        bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-        **kw: Any,
     ) -> ScalarResult[_TSelectParam]:
         ...
 
@@ -52,11 +61,10 @@ class Session(_Session):
         *,
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
         execution_options: Mapping[str, Any] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
+        bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-        **kw: Any,
-    ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
+    ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]:
         results = super().execute(
             statement,
             params=params,
@@ -64,21 +72,40 @@ class Session(_Session):
             bind_arguments=bind_arguments,
             _parent_execute_state=_parent_execute_state,
             _add_event=_add_event,
-            **kw,
         )
         if isinstance(statement, SelectOfScalar):
-            return results.scalars()  # type: ignore
+            return results.scalars()
         return results  # type: ignore
 
-    def execute(
+    @deprecated(
+        """
+        🚨 You probably want to use `session.exec()` instead of `session.execute()`.
+
+        This is the original SQLAlchemy `session.execute()` method that returns objects
+        of type `Row`, and that you have to call `scalars()` to get the model objects.
+
+        For example:
+
+        ```Python
+        heroes = session.execute(select(Hero)).scalars().all()
+        ```
+
+        instead you could use `exec()`:
+
+        ```Python
+        heroes = session.exec(select(Hero)).all()
+        ```
+        """
+    )
+    def execute(  # type: ignore
         self,
         statement: _Executable,
-        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
-        execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT,
-        bind_arguments: Optional[Mapping[str, Any]] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        *,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[Dict[str, Any]] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
-        **kw: Any,
     ) -> Result[Any]:
         """
         🚨 You probably want to use `session.exec()` instead of `session.execute()`.
@@ -98,17 +125,16 @@ class Session(_Session):
         heroes = session.exec(select(Hero)).all()
         ```
         """
-        return super().execute(  # type: ignore
+        return super().execute(
             statement,
             params=params,
             execution_options=execution_options,
             bind_arguments=bind_arguments,
             _parent_execute_state=_parent_execute_state,
             _add_event=_add_event,
-            **kw,
         )
 
-    def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
+    @deprecated(
         """
         🚨 You probably want to use `session.exec()` instead of `session.query()`.
 
@@ -118,24 +144,17 @@ class Session(_Session):
         Or otherwise you might want to use `session.execute()` instead of
         `session.query()`.
         """
-        return super().query(*entities, **kwargs)
+    )
+    def query(  # type: ignore
+        self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+    ) -> _Query[Any]:
+        """
+        🚨 You probably want to use `session.exec()` instead of `session.query()`.
 
-    def get(
-        self,
-        entity: Type[_TSelectParam],
-        ident: Any,
-        options: Optional[Sequence[Any]] = None,
-        populate_existing: bool = False,
-        with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
-        identity_token: Optional[Any] = None,
-        execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT,
-    ) -> Optional[_TSelectParam]:
-        return super().get(
-            entity,
-            ident,
-            options=options,
-            populate_existing=populate_existing,
-            with_for_update=with_for_update,
-            identity_token=identity_token,
-            execution_options=execution_options,
-        )
+        `session.exec()` is SQLModel's own short version with increased type
+        annotations.
+
+        Or otherwise you might want to use `session.execute()` instead of
+        `session.query()`.
+        """
+        return super().query(*entities, **kwargs)
index 264e39cba7cdf8d497f04c23af018f445ff064fb..a8a572501cd1754f0a0e0349739d5196cc07cd6f 100644 (file)
@@ -2,10 +2,10 @@
 
 from datetime import datetime
 from typing import (
-    TYPE_CHECKING,
     Any,
-    Generic,
+    Iterable,
     Mapping,
+    Optional,
     Sequence,
     Tuple,
     Type,
@@ -15,15 +15,223 @@ from typing import (
 )
 from uuid import UUID
 
-from sqlalchemy import Column
-from sqlalchemy.orm import InstrumentedAttribute
-from sqlalchemy.sql.elements import ColumnClause
+import sqlalchemy
+from sqlalchemy import (
+    Column,
+    ColumnElement,
+    Extract,
+    FunctionElement,
+    FunctionFilter,
+    Label,
+    Over,
+    TypeCoerce,
+    WithinGroup,
+)
+from sqlalchemy.orm import InstrumentedAttribute, Mapped
+from sqlalchemy.sql._typing import (
+    _ColumnExpressionArgument,
+    _ColumnExpressionOrLiteralArgument,
+    _ColumnExpressionOrStrLabelArgument,
+)
+from sqlalchemy.sql.elements import (
+    BinaryExpression,
+    Case,
+    Cast,
+    CollectionAggregate,
+    ColumnClause,
+    SQLCoreOperations,
+    TryCast,
+    UnaryExpression,
+)
 from sqlalchemy.sql.expression import Select as _Select
+from sqlalchemy.sql.roles import TypedColumnsClauseRole
+from sqlalchemy.sql.type_api import TypeEngine
+from typing_extensions import Literal, Self
+
+_T = TypeVar("_T")
+
+_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]]
+
+# Redefine operatos that would only take a column expresion to also take the (virtual)
+# types of Pydantic models, e.g. str instead of only Mapped[str].
+
+
+def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
+    return sqlalchemy.all_(expr)  # type: ignore[arg-type]
+
+
+def and_(
+    initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool],
+    *clauses: Union[_ColumnExpressionArgument[bool], bool],
+) -> ColumnElement[bool]:
+    return sqlalchemy.and_(initial_clause, *clauses)  # type: ignore[arg-type]
+
+
+def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
+    return sqlalchemy.any_(expr)  # type: ignore[arg-type]
+
+
+def asc(
+    column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
+) -> UnaryExpression[_T]:
+    return sqlalchemy.asc(column)  # type: ignore[arg-type]
+
+
+def collate(
+    expression: Union[_ColumnExpressionArgument[str], str], collation: str
+) -> BinaryExpression[str]:
+    return sqlalchemy.collate(expression, collation)  # type: ignore[arg-type]
+
+
+def between(
+    expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T],
+    lower_bound: Any,
+    upper_bound: Any,
+    symmetric: bool = False,
+) -> BinaryExpression[bool]:
+    return sqlalchemy.between(expr, lower_bound, upper_bound, symmetric=symmetric)  # type: ignore[arg-type]
+
+
+def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]:
+    return sqlalchemy.not_(clause)  # type: ignore[arg-type]
+
+
+def case(
+    *whens: Union[
+        Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any]
+    ],
+    value: Optional[Any] = None,
+    else_: Optional[Any] = None,
+) -> Case[Any]:
+    return sqlalchemy.case(*whens, value=value, else_=else_)  # type: ignore[arg-type]
+
+
+def cast(
+    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
+    type_: "_TypeEngineArgument[_T]",
+) -> Cast[_T]:
+    return sqlalchemy.cast(expression, type_)  # type: ignore[arg-type]
+
+
+def try_cast(
+    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
+    type_: "_TypeEngineArgument[_T]",
+) -> TryCast[_T]:
+    return sqlalchemy.try_cast(expression, type_)  # type: ignore[arg-type]
+
+
+def desc(
+    column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
+) -> UnaryExpression[_T]:
+    return sqlalchemy.desc(column)  # type: ignore[arg-type]
+
+
+def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
+    return sqlalchemy.distinct(expr)  # type: ignore[arg-type]
+
+
+def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
+    return sqlalchemy.bitwise_not(expr)  # type: ignore[arg-type]
+
+
+def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract:
+    return sqlalchemy.extract(field, expr)  # type: ignore[arg-type]
+
+
+def funcfilter(
+    func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool]
+) -> FunctionFilter[_T]:
+    return sqlalchemy.funcfilter(func, *criterion)  # type: ignore[arg-type]
+
+
+def label(
+    name: str,
+    element: Union[_ColumnExpressionArgument[_T], _T],
+    type_: Optional["_TypeEngineArgument[_T]"] = None,
+) -> Label[_T]:
+    return sqlalchemy.label(name, element, type_=type_)  # type: ignore[arg-type]
+
+
+def nulls_first(
+    column: Union[_ColumnExpressionArgument[_T], _T]
+) -> UnaryExpression[_T]:
+    return sqlalchemy.nulls_first(column)  # type: ignore[arg-type]
+
+
+def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
+    return sqlalchemy.nulls_last(column)  # type: ignore[arg-type]
+
+
+def or_(  # type: ignore[empty-body]
+    initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool],
+    *clauses: Union[_ColumnExpressionArgument[bool], bool],
+) -> ColumnElement[bool]:
+    return sqlalchemy.or_(initial_clause, *clauses)  # type: ignore[arg-type]
+
+
+def over(
+    element: FunctionElement[_T],
+    partition_by: Optional[
+        Union[
+            Iterable[Union[_ColumnExpressionArgument[Any], Any]],
+            _ColumnExpressionArgument[Any],
+            Any,
+        ]
+    ] = None,
+    order_by: Optional[
+        Union[
+            Iterable[Union[_ColumnExpressionArgument[Any], Any]],
+            _ColumnExpressionArgument[Any],
+            Any,
+        ]
+    ] = None,
+    range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
+    rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
+) -> Over[_T]:
+    return sqlalchemy.over(
+        element, partition_by=partition_by, order_by=order_by, range_=range_, rows=rows
+    )  # type: ignore[arg-type]
+
+
+def tuple_(
+    *clauses: Union[_ColumnExpressionArgument[Any], Any],
+    types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None,
+) -> Tuple[Any, ...]:
+    return sqlalchemy.tuple_(*clauses, types=types)  # type: ignore[return-value]
+
+
+def type_coerce(
+    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
+    type_: "_TypeEngineArgument[_T]",
+) -> TypeCoerce[_T]:
+    return sqlalchemy.type_coerce(expression, type_)  # type: ignore[arg-type]
+
+
+def within_group(
+    element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any]
+) -> WithinGroup[_T]:
+    return sqlalchemy.within_group(element, *order_by)  # type: ignore[arg-type]
+
+
+# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share
+# where and having without having type overlap incompatibility in session.exec().
+class SelectBase(_Select[Tuple[_T]]):
+    inherit_cache = True
+
+    def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
+        """Return a new `Select` construct with the given expression added to
+        its `WHERE` clause, joined to the existing clause via `AND`, if any.
+        """
+        return super().where(*whereclause)  # type: ignore[arg-type]
 
-_TSelect = TypeVar("_TSelect")
+    def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
+        """Return a new `Select` construct with the given expression added to
+        its `HAVING` clause, joined to the existing clause via `AND`, if any.
+        """
+        return super().having(*having)  # type: ignore[arg-type]
 
 
-class Select(_Select, Generic[_TSelect]):
+class Select(SelectBase[_T]):
     inherit_cache = True
 
 
@@ -31,12 +239,15 @@ class Select(_Select, Generic[_TSelect]):
 # purpose. This is the same as a normal SQLAlchemy Select class where there's only one
 # entity, so the result will be converted to a scalar by default. This way writing
 # for loops on the results will feel natural.
-class SelectOfScalar(_Select, Generic[_TSelect]):
+class SelectOfScalar(SelectBase[_T]):
     inherit_cache = True
 
 
-if TYPE_CHECKING:  # pragma: no cover
-    from ..main import SQLModel
+_TCCA = Union[
+    TypedColumnsClauseRole[_T],
+    SQLCoreOperations[_T],
+    Type[_T],
+]
 
 # Generated TypeVars start
 
@@ -56,7 +267,7 @@ _TScalar_0 = TypeVar(
     None,
 )
 
-_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
+_T0 = TypeVar("_T0")
 
 
 _TScalar_1 = TypeVar(
@@ -74,7 +285,7 @@ _TScalar_1 = TypeVar(
     None,
 )
 
-_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
+_T1 = TypeVar("_T1")
 
 
 _TScalar_2 = TypeVar(
@@ -92,7 +303,7 @@ _TScalar_2 = TypeVar(
     None,
 )
 
-_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
+_T2 = TypeVar("_T2")
 
 
 _TScalar_3 = TypeVar(
@@ -110,19 +321,19 @@ _TScalar_3 = TypeVar(
     None,
 )
 
-_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
+_T3 = TypeVar("_T3")
 
 
 # Generated TypeVars end
 
 
 @overload
-def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
+def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]:  # type: ignore
     ...
 
 
 @overload
-def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
+def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]:
     ...
 
 
@@ -133,7 +344,6 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:
 def select(  # type: ignore
     entity_0: _TScalar_0,
     entity_1: _TScalar_1,
-    **kw: Any,
 ) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
     ...
 
@@ -141,27 +351,24 @@ def select(  # type: ignore
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1]]:
+    __ent1: _TCCA[_T1],
+) -> Select[Tuple[_TScalar_0, _T1]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1]]:
+) -> Select[Tuple[_T0, _TScalar_1]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1]]:
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+) -> Select[Tuple[_T0, _T1]]:
     ...
 
 
@@ -170,7 +377,6 @@ def select(  # type: ignore
     entity_0: _TScalar_0,
     entity_1: _TScalar_1,
     entity_2: _TScalar_2,
-    **kw: Any,
 ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
     ...
 
@@ -179,69 +385,62 @@ def select(  # type: ignore
 def select(  # type: ignore
     entity_0: _TScalar_0,
     entity_1: _TScalar_1,
-    entity_2: Type[_TModel_2],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
+    __ent2: _TCCA[_T2],
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2]]:
     ...
 
 
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
+    __ent1: _TCCA[_T1],
     entity_2: _TScalar_2,
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
+) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2]]:
     ...
 
 
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
-    entity_2: Type[_TModel_2],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+) -> Select[Tuple[_TScalar_0, _T1, _T2]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
     entity_2: _TScalar_2,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
+) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
-    entity_2: Type[_TModel_2],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
+    __ent2: _TCCA[_T2],
+) -> Select[Tuple[_T0, _TScalar_1, _T2]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
     entity_2: _TScalar_2,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
+) -> Select[Tuple[_T0, _T1, _TScalar_2]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
-    entity_2: Type[_TModel_2],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+) -> Select[Tuple[_T0, _T1, _T2]]:
     ...
 
 
@@ -251,7 +450,6 @@ def select(  # type: ignore
     entity_1: _TScalar_1,
     entity_2: _TScalar_2,
     entity_3: _TScalar_3,
-    **kw: Any,
 ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
     ...
 
@@ -261,9 +459,8 @@ def select(  # type: ignore
     entity_0: _TScalar_0,
     entity_1: _TScalar_1,
     entity_2: _TScalar_2,
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _T3]]:
     ...
 
 
@@ -271,10 +468,9 @@ def select(  # type: ignore
 def select(  # type: ignore
     entity_0: _TScalar_0,
     entity_1: _TScalar_1,
-    entity_2: Type[_TModel_2],
+    __ent2: _TCCA[_T2],
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2, _TScalar_3]]:
     ...
 
 
@@ -282,156 +478,142 @@ def select(  # type: ignore
 def select(  # type: ignore
     entity_0: _TScalar_0,
     entity_1: _TScalar_1,
-    entity_2: Type[_TModel_2],
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _T2, _T3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
+    __ent1: _TCCA[_T1],
     entity_2: _TScalar_2,
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
+) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2, _TScalar_3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
+    __ent1: _TCCA[_T1],
     entity_2: _TScalar_2,
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_TScalar_0, _T1, _TScalar_2, _T3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
-    entity_2: Type[_TModel_2],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
+) -> Select[Tuple[_TScalar_0, _T1, _T2, _TScalar_3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
     entity_0: _TScalar_0,
-    entity_1: Type[_TModel_1],
-    entity_2: Type[_TModel_2],
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_TScalar_0, _T1, _T2, _T3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
     entity_2: _TScalar_2,
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
+) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2, _TScalar_3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
     entity_2: _TScalar_2,
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _TScalar_1, _TScalar_2, _T3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
-    entity_2: Type[_TModel_2],
+    __ent2: _TCCA[_T2],
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
+) -> Select[Tuple[_T0, _TScalar_1, _T2, _TScalar_3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
+    __ent0: _TCCA[_T0],
     entity_1: _TScalar_1,
-    entity_2: Type[_TModel_2],
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _TScalar_1, _T2, _T3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
     entity_2: _TScalar_2,
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
+) -> Select[Tuple[_T0, _T1, _TScalar_2, _TScalar_3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
     entity_2: _TScalar_2,
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _T1, _TScalar_2, _T3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
-    entity_2: Type[_TModel_2],
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
     entity_3: _TScalar_3,
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
+) -> Select[Tuple[_T0, _T1, _T2, _TScalar_3]]:
     ...
 
 
 @overload
 def select(  # type: ignore
-    entity_0: Type[_TModel_0],
-    entity_1: Type[_TModel_1],
-    entity_2: Type[_TModel_2],
-    entity_3: Type[_TModel_3],
-    **kw: Any,
-) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
+    __ent0: _TCCA[_T0],
+    __ent1: _TCCA[_T1],
+    __ent2: _TCCA[_T2],
+    __ent3: _TCCA[_T3],
+) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
     ...
 
 
 # Generated overloads end
 
 
-def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
+def select(*entities: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
     if len(entities) == 1:
-        return SelectOfScalar._create(*entities, **kw)  # type: ignore
-    return Select._create(*entities, **kw)  # type: ignore
+        return SelectOfScalar(*entities)
+    return Select(*entities)
 
 
-# TODO: add several @overload from Python types to SQLAlchemy equivalents
-def col(column_expression: Any) -> ColumnClause:  # type: ignore
+def col(column_expression: _T) -> Mapped[_T]:
     if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
         raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
-    return column_expression
+    return column_expression  # type: ignore
index 26d12a0395743eb1b5ca5e87e05125a08fe7d3bc..f1a25419c02f66a503695200d12bc21ff93941d7 100644 (file)
@@ -1,9 +1,9 @@
 from datetime import datetime
 from typing import (
-    TYPE_CHECKING,
     Any,
-    Generic,
+    Iterable,
     Mapping,
+    Optional,
     Sequence,
     Tuple,
     Type,
@@ -13,28 +13,243 @@ from typing import (
 )
 from uuid import UUID
 
-from sqlalchemy import Column
-from sqlalchemy.orm import InstrumentedAttribute
-from sqlalchemy.sql.elements import ColumnClause
+import sqlalchemy
+from sqlalchemy import (
+    Column,
+    ColumnElement,
+    Extract,
+    FunctionElement,
+    FunctionFilter,
+    Label,
+    Over,
+    TypeCoerce,
+    WithinGroup,
+)
+from sqlalchemy.orm import InstrumentedAttribute, Mapped
+from sqlalchemy.sql._typing import (
+    _ColumnExpressionArgument,
+    _ColumnExpressionOrLiteralArgument,
+    _ColumnExpressionOrStrLabelArgument,
+)
+from sqlalchemy.sql.elements import (
+    BinaryExpression,
+    Case,
+    Cast,
+    CollectionAggregate,
+    ColumnClause,
+    SQLCoreOperations,
+    TryCast,
+    UnaryExpression,
+)
 from sqlalchemy.sql.expression import Select as _Select
+from sqlalchemy.sql.roles import TypedColumnsClauseRole
+from sqlalchemy.sql.type_api import TypeEngine
+from typing_extensions import Literal, Self
+
+_T = TypeVar("_T")
+
+_TypeEngineArgument = Union[Type[TypeEngine[_T]], TypeEngine[_T]]
+
+# Redefine operatos that would only take a column expresion to also take the (virtual)
+# types of Pydantic models, e.g. str instead of only Mapped[str].
+
+
+def all_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
+    return sqlalchemy.all_(expr)  # type: ignore[arg-type]
+
+
+def and_(
+    initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool], bool],
+    *clauses: Union[_ColumnExpressionArgument[bool], bool],
+) -> ColumnElement[bool]:
+    return sqlalchemy.and_(initial_clause, *clauses)  # type: ignore[arg-type]
+
+
+def any_(expr: Union[_ColumnExpressionArgument[_T], _T]) -> CollectionAggregate[bool]:
+    return sqlalchemy.any_(expr)  # type: ignore[arg-type]
+
+
+def asc(
+    column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
+) -> UnaryExpression[_T]:
+    return sqlalchemy.asc(column)  # type: ignore[arg-type]
+
+
+def collate(
+    expression: Union[_ColumnExpressionArgument[str], str], collation: str
+) -> BinaryExpression[str]:
+    return sqlalchemy.collate(expression, collation)  # type: ignore[arg-type]
+
+
+def between(
+    expr: Union[_ColumnExpressionOrLiteralArgument[_T], _T],
+    lower_bound: Any,
+    upper_bound: Any,
+    symmetric: bool = False,
+) -> BinaryExpression[bool]:
+    return sqlalchemy.between(expr, lower_bound, upper_bound, symmetric=symmetric)  # type: ignore[arg-type]
+
+
+def not_(clause: Union[_ColumnExpressionArgument[_T], _T]) -> ColumnElement[_T]:
+    return sqlalchemy.not_(clause)  # type: ignore[arg-type]
+
+
+def case(
+    *whens: Union[
+        Tuple[Union[_ColumnExpressionArgument[bool], bool], Any], Mapping[Any, Any]
+    ],
+    value: Optional[Any] = None,
+    else_: Optional[Any] = None,
+) -> Case[Any]:
+    return sqlalchemy.case(*whens, value=value, else_=else_)  # type: ignore[arg-type]
+
+
+def cast(
+    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
+    type_: "_TypeEngineArgument[_T]",
+) -> Cast[_T]:
+    return sqlalchemy.cast(expression, type_)  # type: ignore[arg-type]
+
+
+def try_cast(
+    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
+    type_: "_TypeEngineArgument[_T]",
+) -> TryCast[_T]:
+    return sqlalchemy.try_cast(expression, type_)  # type: ignore[arg-type]
+
+
+def desc(
+    column: Union[_ColumnExpressionOrStrLabelArgument[_T], _T],
+) -> UnaryExpression[_T]:
+    return sqlalchemy.desc(column)  # type: ignore[arg-type]
+
+
+def distinct(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
+    return sqlalchemy.distinct(expr)  # type: ignore[arg-type]
+
 
-_TSelect = TypeVar("_TSelect")
+def bitwise_not(expr: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
+    return sqlalchemy.bitwise_not(expr)  # type: ignore[arg-type]
 
-class Select(_Select, Generic[_TSelect]):
+
+def extract(field: str, expr: Union[_ColumnExpressionArgument[Any], Any]) -> Extract:
+    return sqlalchemy.extract(field, expr)  # type: ignore[arg-type]
+
+
+def funcfilter(
+    func: FunctionElement[_T], *criterion: Union[_ColumnExpressionArgument[bool], bool]
+) -> FunctionFilter[_T]:
+    return sqlalchemy.funcfilter(func, *criterion)  # type: ignore[arg-type]
+
+
+def label(
+    name: str,
+    element: Union[_ColumnExpressionArgument[_T], _T],
+    type_: Optional["_TypeEngineArgument[_T]"] = None,
+) -> Label[_T]:
+    return sqlalchemy.label(name, element, type_=type_)  # type: ignore[arg-type]
+
+
+def nulls_first(
+    column: Union[_ColumnExpressionArgument[_T], _T]
+) -> UnaryExpression[_T]:
+    return sqlalchemy.nulls_first(column)  # type: ignore[arg-type]
+
+
+def nulls_last(column: Union[_ColumnExpressionArgument[_T], _T]) -> UnaryExpression[_T]:
+    return sqlalchemy.nulls_last(column)  # type: ignore[arg-type]
+
+
+def or_(  # type: ignore[empty-body]
+    initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool], bool],
+    *clauses: Union[_ColumnExpressionArgument[bool], bool],
+) -> ColumnElement[bool]:
+    return sqlalchemy.or_(initial_clause, *clauses)  # type: ignore[arg-type]
+
+
+def over(
+    element: FunctionElement[_T],
+    partition_by: Optional[
+        Union[
+            Iterable[Union[_ColumnExpressionArgument[Any], Any]],
+            _ColumnExpressionArgument[Any],
+            Any,
+        ]
+    ] = None,
+    order_by: Optional[
+        Union[
+            Iterable[Union[_ColumnExpressionArgument[Any], Any]],
+            _ColumnExpressionArgument[Any],
+            Any,
+        ]
+    ] = None,
+    range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
+    rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
+) -> Over[_T]:
+    return sqlalchemy.over(
+        element, partition_by=partition_by, order_by=order_by, range_=range_, rows=rows
+    )  # type: ignore[arg-type]
+
+
+def tuple_(
+    *clauses: Union[_ColumnExpressionArgument[Any], Any],
+    types: Optional[Sequence["_TypeEngineArgument[Any]"]] = None,
+) -> Tuple[Any, ...]:
+    return sqlalchemy.tuple_(*clauses, types=types)  # type: ignore[return-value]
+
+
+def type_coerce(
+    expression: Union[_ColumnExpressionOrLiteralArgument[Any], Any],
+    type_: "_TypeEngineArgument[_T]",
+) -> TypeCoerce[_T]:
+    return sqlalchemy.type_coerce(expression, type_)  # type: ignore[arg-type]
+
+
+def within_group(
+    element: FunctionElement[_T], *order_by: Union[_ColumnExpressionArgument[Any], Any]
+) -> WithinGroup[_T]:
+    return sqlalchemy.within_group(element, *order_by)  # type: ignore[arg-type]
+
+
+# Separate this class in SelectBase, Select, and SelectOfScalar so that they can share
+# where and having without having type overlap incompatibility in session.exec().
+class SelectBase(_Select[Tuple[_T]]):
     inherit_cache = True
 
+    def where(self, *whereclause: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
+        """Return a new `Select` construct with the given expression added to
+        its `WHERE` clause, joined to the existing clause via `AND`, if any.
+        """
+        return super().where(*whereclause)  # type: ignore[arg-type]
+
+    def having(self, *having: Union[_ColumnExpressionArgument[bool], bool]) -> Self:
+        """Return a new `Select` construct with the given expression added to
+        its `HAVING` clause, joined to the existing clause via `AND`, if any.
+        """
+        return super().having(*having)  # type: ignore[arg-type]
+
+
+class Select(SelectBase[_T]):
+    inherit_cache = True
+
+
 # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
 # purpose. This is the same as a normal SQLAlchemy Select class where there's only one
 # entity, so the result will be converted to a scalar by default. This way writing
 # for loops on the results will feel natural.
-class SelectOfScalar(_Select, Generic[_TSelect]):
+class SelectOfScalar(SelectBase[_T]):
     inherit_cache = True
 
-if TYPE_CHECKING:  # pragma: no cover
-    from ..main import SQLModel
+
+_TCCA = Union[
+    TypedColumnsClauseRole[_T],
+    SQLCoreOperations[_T],
+    Type[_T],
+]
 
 # Generated TypeVars start
 
+
 {% for i in range(number_of_types) %}
 _TScalar_{{ i }} = TypeVar(
     "_TScalar_{{ i }}",
@@ -51,19 +266,19 @@ _TScalar_{{ i }} = TypeVar(
     None,
 )
 
-_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
+_T{{ i }} = TypeVar("_T{{ i }}")
 
 {% endfor %}
 
 # Generated TypeVars end
 
 @overload
-def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
+def select(__ent0: _TScalar_0) -> SelectOfScalar[_TScalar_0]:  # type: ignore
     ...
 
 
 @overload
-def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
+def select(__ent0: _TCCA[_T0]) -> SelectOfScalar[_T0]:
     ...
 
 
@@ -73,7 +288,7 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:
 
 @overload
 def select(  # type: ignore
-    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
+    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}
     ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
     ...
 
@@ -81,14 +296,14 @@ def select(  # type: ignore
 
 # Generated overloads end
 
-def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
+
+def select(*entities: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
     if len(entities) == 1:
-        return SelectOfScalar._create(*entities, **kw)  # type: ignore
-    return Select._create(*entities, **kw)  # type: ignore
+        return SelectOfScalar(*entities)
+    return Select(*entities)
 
 
-# TODO: add several @overload from Python types to SQLAlchemy equivalents
-def col(column_expression: Any) -> ColumnClause:  # type: ignore
+def col(column_expression: _T) -> Mapped[_T]:
     if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
         raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
-    return column_expression
+    return column_expression  # type: ignore
index 17d9b06126d12806af09ab5b343d4fb3a525b3c0..5a4bb04ef12a6f424eb2a07addb98a7dcef5a360 100644 (file)
@@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator):  # type: ignore
     def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
         impl = cast(types.String, self.impl)
         if impl.length is None and dialect.name == "mysql":
-            return dialect.type_descriptor(types.String(self.mysql_default_length))  # type: ignore
+            return dialect.type_descriptor(types.String(self.mysql_default_length))
         return super().load_dialect_impl(dialect)
 
 
@@ -32,11 +32,11 @@ class GUID(types.TypeDecorator):  # type: ignore
     impl = CHAR
     cache_ok = True
 
-    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine:  # type: ignore
+    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
         if dialect.name == "postgresql":
-            return dialect.type_descriptor(UUID())  # type: ignore
+            return dialect.type_descriptor(UUID())
         else:
-            return dialect.type_descriptor(CHAR(32))  # type: ignore
+            return dialect.type_descriptor(CHAR(32))
 
     def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
         if value is None:
index b08affb92095863c660b83cfe0236f0951011e9d..6a55d6cb986bb934fc51c95b49fec78bd97befaa 100644 (file)
@@ -59,7 +59,7 @@ def test_tutorial(clear_sqlmodel):
         response = client.get("/openapi.json")
         assert response.status_code == 200, response.text
         assert response.json() == {
-            "openapi": "3.0.2",
+            "openapi": "3.1.0",
             "info": {"title": "FastAPI", "version": "0.1.0"},
             "paths": {
                 "/heroes/": {
@@ -315,7 +315,9 @@ def test_tutorial(clear_sqlmodel):
                             "loc": {
                                 "title": "Location",
                                 "type": "array",
-                                "items": {"type": "string"},
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
                             },
                             "msg": {"title": "Message", "type": "string"},
                             "type": {"title": "Error Type", "type": "string"},
index 0aee3ca004463f0592e75a8af01c1292843085dc..2709231504a8b57a847f0321e60633ae26c5dc93 100644 (file)
@@ -64,7 +64,7 @@ def test_tutorial(clear_sqlmodel):
         response = client.get("/openapi.json")
         assert response.status_code == 200, response.text
         assert response.json() == {
-            "openapi": "3.0.2",
+            "openapi": "3.1.0",
             "info": {"title": "FastAPI", "version": "0.1.0"},
             "paths": {
                 "/heroes/": {
@@ -239,7 +239,9 @@ def test_tutorial(clear_sqlmodel):
                             "loc": {
                                 "title": "Location",
                                 "type": "array",
-                                "items": {"type": "string"},
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
                             },
                             "msg": {"title": "Message", "type": "string"},
                             "type": {"title": "Error Type", "type": "string"},
index 8d99cf9f5bdd83793e8b38f65268eee17d461427..dc5a3cb8ff6bd39925dbef2699909874d2b134f8 100644 (file)
@@ -5,7 +5,7 @@ from sqlmodel import create_engine
 from sqlmodel.pool import StaticPool
 
 openapi_schema = {
-    "openapi": "3.0.2",
+    "openapi": "3.1.0",
     "info": {"title": "FastAPI", "version": "0.1.0"},
     "paths": {
         "/heroes/": {
@@ -103,7 +103,7 @@ openapi_schema = {
                     "loc": {
                         "title": "Location",
                         "type": "array",
-                        "items": {"type": "string"},
+                        "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                     },
                     "msg": {"title": "Message", "type": "string"},
                     "type": {"title": "Error Type", "type": "string"},
index 94a41b307622d689d8f233298408aee18604f784..e3c20404c0bd1f9656bf70e370e3a604310525bd 100644 (file)
@@ -5,7 +5,7 @@ from sqlmodel import create_engine
 from sqlmodel.pool import StaticPool
 
 openapi_schema = {
-    "openapi": "3.0.2",
+    "openapi": "3.1.0",
     "info": {"title": "FastAPI", "version": "0.1.0"},
     "paths": {
         "/heroes/": {
@@ -103,7 +103,7 @@ openapi_schema = {
                     "loc": {
                         "title": "Location",
                         "type": "array",
-                        "items": {"type": "string"},
+                        "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                     },
                     "msg": {"title": "Message", "type": "string"},
                     "type": {"title": "Error Type", "type": "string"},
index 0609ae41ffe287955daffe2fe9998074f6bf3a00..0a599574d5dba84b13e362df1ebee307a9db872d 100644 (file)
@@ -3,7 +3,7 @@ from sqlmodel import create_engine
 from sqlmodel.pool import StaticPool
 
 openapi_schema = {
-    "openapi": "3.0.2",
+    "openapi": "3.1.0",
     "info": {"title": "FastAPI", "version": "0.1.0"},
     "paths": {
         "/heroes/": {
@@ -135,7 +135,7 @@ openapi_schema = {
                     "loc": {
                         "title": "Location",
                         "type": "array",
-                        "items": {"type": "string"},
+                        "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                     },
                     "msg": {"title": "Message", "type": "string"},
                     "type": {"title": "Error Type", "type": "string"},
index 8869862e95baa804bea49698eb351e6a779c66b2..fb08b9a5fddd13db73f58013514bc0996e012bdd 100644 (file)
@@ -107,7 +107,7 @@ def test_tutorial(clear_sqlmodel):
         response = client.get("/openapi.json")
         assert response.status_code == 200, response.text
         assert response.json() == {
-            "openapi": "3.0.2",
+            "openapi": "3.1.0",
             "info": {"title": "FastAPI", "version": "0.1.0"},
             "paths": {
                 "/heroes/": {
@@ -622,7 +622,9 @@ def test_tutorial(clear_sqlmodel):
                             "loc": {
                                 "title": "Location",
                                 "type": "array",
-                                "items": {"type": "string"},
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
                             },
                             "msg": {"title": "Message", "type": "string"},
                             "type": {"title": "Error Type", "type": "string"},
index ebb3046ef31a68a8f0c95bd8a6e9acc9551fabfb..968fefa8ca6dd2f88b2e430b7b5cab071e8c5bf4 100644 (file)
@@ -3,7 +3,7 @@ from sqlmodel import create_engine
 from sqlmodel.pool import StaticPool
 
 openapi_schema = {
-    "openapi": "3.0.2",
+    "openapi": "3.1.0",
     "info": {"title": "FastAPI", "version": "0.1.0"},
     "paths": {
         "/heroes/": {
@@ -91,7 +91,7 @@ openapi_schema = {
                     "loc": {
                         "title": "Location",
                         "type": "array",
-                        "items": {"type": "string"},
+                        "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                     },
                     "msg": {"title": "Message", "type": "string"},
                     "type": {"title": "Error Type", "type": "string"},
index cb0a6f9282a3b6aa28b3779da1f75b8d4fec650e..6f97cbf92be18ff79ec528b61610aff1528f23e7 100644 (file)
@@ -59,7 +59,7 @@ def test_tutorial(clear_sqlmodel):
         response = client.get("/openapi.json")
         assert response.status_code == 200, response.text
         assert response.json() == {
-            "openapi": "3.0.2",
+            "openapi": "3.1.0",
             "info": {"title": "FastAPI", "version": "0.1.0"},
             "paths": {
                 "/heroes/": {
@@ -315,7 +315,9 @@ def test_tutorial(clear_sqlmodel):
                             "loc": {
                                 "title": "Location",
                                 "type": "array",
-                                "items": {"type": "string"},
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
                             },
                             "msg": {"title": "Message", "type": "string"},
                             "type": {"title": "Error Type", "type": "string"},
index eb834ec2a4ae55f0f11a95656e71480b2c094118..435155d6e93873ff3e1993b9ad70e20c10c877b3 100644 (file)
@@ -3,7 +3,7 @@ from sqlmodel import create_engine
 from sqlmodel.pool import StaticPool
 
 openapi_schema = {
-    "openapi": "3.0.2",
+    "openapi": "3.1.0",
     "info": {"title": "FastAPI", "version": "0.1.0"},
     "paths": {
         "/heroes/": {
@@ -79,7 +79,7 @@ openapi_schema = {
                     "loc": {
                         "title": "Location",
                         "type": "array",
-                        "items": {"type": "string"},
+                        "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                     },
                     "msg": {"title": "Message", "type": "string"},
                     "type": {"title": "Error Type", "type": "string"},
index e66c975142f3b464ba5013480e8895e7e60da274..42f87cef76b6078aae51d6d5e15eae400cb4ec47 100644 (file)
@@ -94,7 +94,7 @@ def test_tutorial(clear_sqlmodel):
         response = client.get("/openapi.json")
         assert response.status_code == 200, response.text
         assert response.json() == {
-            "openapi": "3.0.2",
+            "openapi": "3.1.0",
             "info": {"title": "FastAPI", "version": "0.1.0"},
             "paths": {
                 "/heroes/": {
@@ -579,7 +579,9 @@ def test_tutorial(clear_sqlmodel):
                             "loc": {
                                 "title": "Location",
                                 "type": "array",
-                                "items": {"type": "string"},
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
                             },
                             "msg": {"title": "Message", "type": "string"},
                             "type": {"title": "Error Type", "type": "string"},
index 49906256c9cfd50e1eb7e9b680f402bfc5440d56..a4573ef11bc51555f7dccebbd42bb7032bbea48e 100644 (file)
@@ -66,7 +66,7 @@ def test_tutorial(clear_sqlmodel):
         response = client.get("/openapi.json")
         assert response.status_code == 200, response.text
         assert response.json() == {
-            "openapi": "3.0.2",
+            "openapi": "3.1.0",
             "info": {"title": "FastAPI", "version": "0.1.0"},
             "paths": {
                 "/heroes/": {
@@ -294,7 +294,9 @@ def test_tutorial(clear_sqlmodel):
                             "loc": {
                                 "title": "Location",
                                 "type": "array",
-                                "items": {"type": "string"},
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
                             },
                             "msg": {"title": "Message", "type": "string"},
                             "type": {"title": "Error Type", "type": "string"},