--- /dev/null
+__version__ = "0.0.1"
+
+# Re-export from SQLAlchemy
+from sqlalchemy.engine import create_mock_engine as create_mock_engine
+from sqlalchemy.engine import engine_from_config as engine_from_config
+from sqlalchemy.inspection import inspect as inspect
+from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
+from sqlalchemy.schema import CheckConstraint as CheckConstraint
+from sqlalchemy.schema import Column as Column
+from sqlalchemy.schema import ColumnDefault as ColumnDefault
+from sqlalchemy.schema import Computed as Computed
+from sqlalchemy.schema import Constraint as Constraint
+from sqlalchemy.schema import DDL as DDL
+from sqlalchemy.schema import DefaultClause as DefaultClause
+from sqlalchemy.schema import FetchedValue as FetchedValue
+from sqlalchemy.schema import ForeignKey as ForeignKey
+from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint
+from sqlalchemy.schema import Identity as Identity
+from sqlalchemy.schema import Index as Index
+from sqlalchemy.schema import MetaData as MetaData
+from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
+from sqlalchemy.schema import Sequence as Sequence
+from sqlalchemy.schema import Table as Table
+from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
+from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
+from sqlalchemy.sql import alias as alias
+from sqlalchemy.sql import all_ as all_
+from sqlalchemy.sql import and_ as and_
+from sqlalchemy.sql import any_ as any_
+from sqlalchemy.sql import asc as asc
+from sqlalchemy.sql import between as between
+from sqlalchemy.sql import bindparam as bindparam
+from sqlalchemy.sql import case as case
+from sqlalchemy.sql import cast as cast
+from sqlalchemy.sql import collate as collate
+from sqlalchemy.sql import column as column
+from sqlalchemy.sql import delete as delete
+from sqlalchemy.sql import desc as desc
+from sqlalchemy.sql import distinct as distinct
+from sqlalchemy.sql import except_ as except_
+from sqlalchemy.sql import except_all as except_all
+from sqlalchemy.sql import exists as exists
+from sqlalchemy.sql import extract as extract
+from sqlalchemy.sql import false as false
+from sqlalchemy.sql import func as func
+from sqlalchemy.sql import funcfilter as funcfilter
+from sqlalchemy.sql import insert as insert
+from sqlalchemy.sql import intersect as intersect
+from sqlalchemy.sql import intersect_all as intersect_all
+from sqlalchemy.sql import join as join
+from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
+from sqlalchemy.sql import (
+ LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
+)
+from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
+from sqlalchemy.sql import (
+ LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
+)
+from sqlalchemy.sql import lambda_stmt as lambda_stmt
+from sqlalchemy.sql import lateral as lateral
+from sqlalchemy.sql import literal as literal
+from sqlalchemy.sql import literal_column as literal_column
+from sqlalchemy.sql import modifier as modifier
+from sqlalchemy.sql import not_ as not_
+from sqlalchemy.sql import null as null
+from sqlalchemy.sql import nulls_first as nulls_first
+from sqlalchemy.sql import nulls_last as nulls_last
+from sqlalchemy.sql import nullsfirst as nullsfirst
+from sqlalchemy.sql import nullslast as nullslast
+from sqlalchemy.sql import or_ as or_
+from sqlalchemy.sql import outerjoin as outerjoin
+from sqlalchemy.sql import outparam as outparam
+from sqlalchemy.sql import over as over
+from sqlalchemy.sql import subquery as subquery
+from sqlalchemy.sql import table as table
+from sqlalchemy.sql import tablesample as tablesample
+from sqlalchemy.sql import text as text
+from sqlalchemy.sql import true as true
+from sqlalchemy.sql import tuple_ as tuple_
+from sqlalchemy.sql import type_coerce as type_coerce
+from sqlalchemy.sql import union as union
+from sqlalchemy.sql import union_all as union_all
+from sqlalchemy.sql import update as update
+from sqlalchemy.sql import values as values
+from sqlalchemy.sql import within_group as within_group
+from sqlalchemy.types import ARRAY as ARRAY
+from sqlalchemy.types import BIGINT as BIGINT
+from sqlalchemy.types import BigInteger as BigInteger
+from sqlalchemy.types import BINARY as BINARY
+from sqlalchemy.types import BLOB as BLOB
+from sqlalchemy.types import BOOLEAN as BOOLEAN
+from sqlalchemy.types import Boolean as Boolean
+from sqlalchemy.types import CHAR as CHAR
+from sqlalchemy.types import CLOB as CLOB
+from sqlalchemy.types import DATE as DATE
+from sqlalchemy.types import Date as Date
+from sqlalchemy.types import DATETIME as DATETIME
+from sqlalchemy.types import DateTime as DateTime
+from sqlalchemy.types import DECIMAL as DECIMAL
+from sqlalchemy.types import Enum as Enum
+from sqlalchemy.types import FLOAT as FLOAT
+from sqlalchemy.types import Float as Float
+from sqlalchemy.types import INT as INT
+from sqlalchemy.types import INTEGER as INTEGER
+from sqlalchemy.types import Integer as Integer
+from sqlalchemy.types import Interval as Interval
+from sqlalchemy.types import JSON as JSON
+from sqlalchemy.types import LargeBinary as LargeBinary
+from sqlalchemy.types import NCHAR as NCHAR
+from sqlalchemy.types import NUMERIC as NUMERIC
+from sqlalchemy.types import Numeric as Numeric
+from sqlalchemy.types import NVARCHAR as NVARCHAR
+from sqlalchemy.types import PickleType as PickleType
+from sqlalchemy.types import REAL as REAL
+from sqlalchemy.types import SMALLINT as SMALLINT
+from sqlalchemy.types import SmallInteger as SmallInteger
+from sqlalchemy.types import String as String
+from sqlalchemy.types import TEXT as TEXT
+from sqlalchemy.types import Text as Text
+from sqlalchemy.types import TIME as TIME
+from sqlalchemy.types import Time as Time
+from sqlalchemy.types import TIMESTAMP as TIMESTAMP
+from sqlalchemy.types import TypeDecorator as TypeDecorator
+from sqlalchemy.types import Unicode as Unicode
+from sqlalchemy.types import UnicodeText as UnicodeText
+from sqlalchemy.types import VARBINARY as VARBINARY
+from sqlalchemy.types import VARCHAR as VARCHAR
+
+# Extensions and modifications of SQLAlchemy in SQLModel
+from .engine.create import create_engine as create_engine
+from .orm.session import Session as Session
+from .sql.expression import select as select
+from .sql.expression import col as col
+from .sql.sqltypes import AutoString as AutoString
+
+# Export SQLModel specifics (equivalent to Pydantic)
+from .main import SQLModel as SQLModel
+from .main import Field as Field
+from .main import Relationship as Relationship
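+
+# With everything re-exported above, user code can import what it needs from a
+# single place, e.g.:
+#
+#     from sqlmodel import Field, Session, SQLModel, create_engine, select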
--- /dev/null
+from typing import Any, TypeVar
+
+
+class _DefaultPlaceholder:
+ """
+ You shouldn't use this class directly.
+
+ It's used internally to recognize when a default value has been overwritten, even
+    if the overridden default value was truthy.
+ """
+
+ def __init__(self, value: Any):
+ self.value = value
+
+ def __bool__(self) -> bool:
+ return bool(self.value)
+
+ def __eq__(self, o: object) -> bool:
+ return isinstance(o, _DefaultPlaceholder) and o.value == self.value
+
+
+_TDefaultType = TypeVar("_TDefaultType")
+
+
+def Default(value: _TDefaultType) -> _TDefaultType:
+ """
+ You shouldn't use this function directly.
+
+ It's used internally to recognize when a default value has been overwritten, even
+    if the overridden default value was truthy.
+ """
+ return _DefaultPlaceholder(value) # type: ignore
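+
+
+# A minimal sketch of how the pattern is used (the function and parameter
+# names here are illustrative):
+#
+#     def configure(*, echo: bool = Default(False)) -> None:
+#         if not isinstance(echo, _DefaultPlaceholder):
+#             ...  # echo was passed explicitly, even if it was False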
--- /dev/null
+import json
+import sqlite3
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+from sqlalchemy import create_engine as _create_engine
+from sqlalchemy.engine.url import URL
+from sqlalchemy.future import Engine as _FutureEngine
+from sqlalchemy.pool import Pool
+from typing_extensions import Literal, TypedDict
+
+from ..default import Default, _DefaultPlaceholder
+
+# These types are defined in sqlalchemy2-stubs, but they can't be imported at
+# runtime, so re-define them here
+
+_Debug = Literal["debug"]
+
+_IsolationLevel = Literal[
+ "SERIALIZABLE",
+ "REPEATABLE READ",
+ "READ COMMITTED",
+ "READ UNCOMMITTED",
+ "AUTOCOMMIT",
+]
+_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
+_ResetOnReturn = Literal["rollback", "commit"]
+
+
+class _SQLiteConnectArgs(TypedDict, total=False):
+ timeout: float
+ detect_types: Any
+ isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
+ check_same_thread: bool
+ factory: Type[sqlite3.Connection]
+ cached_statements: int
+ uri: bool
+
+
+_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
+
+
+# Re-define create_engine to have future=True by default, and assume that's what
+# is used. Also show the default value of each parameter in the signature, but
+# only pass a parameter on to SQLAlchemy when the user sets it explicitly;
+# passing them all unconditionally would cause errors, e.g. SQLite doesn't
+# support pool connection arguments.
+def create_engine(
+ url: Union[str, URL],
+ *,
+ connect_args: _ConnectArgs = Default({}), # type: ignore
+ echo: Union[bool, _Debug] = Default(False),
+ echo_pool: Union[bool, _Debug] = Default(False),
+ enable_from_linting: bool = Default(True),
+ encoding: str = Default("utf-8"),
+ execution_options: Dict[Any, Any] = Default({}),
+ future: bool = True,
+ hide_parameters: bool = Default(False),
+ implicit_returning: bool = Default(True),
+ isolation_level: Optional[_IsolationLevel] = Default(None),
+ json_deserializer: Callable[..., Any] = Default(json.loads),
+ json_serializer: Callable[..., Any] = Default(json.dumps),
+ label_length: Optional[int] = Default(None),
+ logging_name: Optional[str] = Default(None),
+ max_identifier_length: Optional[int] = Default(None),
+ max_overflow: int = Default(10),
+ module: Optional[Any] = Default(None),
+ paramstyle: Optional[_ParamStyle] = Default(None),
+ pool: Optional[Pool] = Default(None),
+ poolclass: Optional[Type[Pool]] = Default(None),
+ pool_logging_name: Optional[str] = Default(None),
+ pool_pre_ping: bool = Default(False),
+ pool_size: int = Default(5),
+ pool_recycle: int = Default(-1),
+ pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
+ pool_timeout: float = Default(30),
+ pool_use_lifo: bool = Default(False),
+ plugins: Optional[List[str]] = Default(None),
+ query_cache_size: Optional[int] = Default(None),
+ **kwargs: Any,
+) -> _FutureEngine:
+ current_kwargs: Dict[str, Any] = {
+ "future": future,
+ }
+ if not isinstance(echo, _DefaultPlaceholder):
+ current_kwargs["echo"] = echo
+ if not isinstance(echo_pool, _DefaultPlaceholder):
+ current_kwargs["echo_pool"] = echo_pool
+ if not isinstance(enable_from_linting, _DefaultPlaceholder):
+ current_kwargs["enable_from_linting"] = enable_from_linting
+ if not isinstance(connect_args, _DefaultPlaceholder):
+ current_kwargs["connect_args"] = connect_args
+ if not isinstance(encoding, _DefaultPlaceholder):
+ current_kwargs["encoding"] = encoding
+ if not isinstance(execution_options, _DefaultPlaceholder):
+ current_kwargs["execution_options"] = execution_options
+ if not isinstance(hide_parameters, _DefaultPlaceholder):
+ current_kwargs["hide_parameters"] = hide_parameters
+ if not isinstance(implicit_returning, _DefaultPlaceholder):
+ current_kwargs["implicit_returning"] = implicit_returning
+ if not isinstance(isolation_level, _DefaultPlaceholder):
+ current_kwargs["isolation_level"] = isolation_level
+ if not isinstance(json_deserializer, _DefaultPlaceholder):
+ current_kwargs["json_deserializer"] = json_deserializer
+ if not isinstance(json_serializer, _DefaultPlaceholder):
+ current_kwargs["json_serializer"] = json_serializer
+ if not isinstance(label_length, _DefaultPlaceholder):
+ current_kwargs["label_length"] = label_length
+ if not isinstance(logging_name, _DefaultPlaceholder):
+ current_kwargs["logging_name"] = logging_name
+ if not isinstance(max_identifier_length, _DefaultPlaceholder):
+ current_kwargs["max_identifier_length"] = max_identifier_length
+ if not isinstance(max_overflow, _DefaultPlaceholder):
+ current_kwargs["max_overflow"] = max_overflow
+ if not isinstance(module, _DefaultPlaceholder):
+ current_kwargs["module"] = module
+ if not isinstance(paramstyle, _DefaultPlaceholder):
+ current_kwargs["paramstyle"] = paramstyle
+ if not isinstance(pool, _DefaultPlaceholder):
+ current_kwargs["pool"] = pool
+ if not isinstance(poolclass, _DefaultPlaceholder):
+ current_kwargs["poolclass"] = poolclass
+ if not isinstance(pool_logging_name, _DefaultPlaceholder):
+ current_kwargs["pool_logging_name"] = pool_logging_name
+ if not isinstance(pool_pre_ping, _DefaultPlaceholder):
+ current_kwargs["pool_pre_ping"] = pool_pre_ping
+ if not isinstance(pool_size, _DefaultPlaceholder):
+ current_kwargs["pool_size"] = pool_size
+ if not isinstance(pool_recycle, _DefaultPlaceholder):
+ current_kwargs["pool_recycle"] = pool_recycle
+ if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
+ current_kwargs["pool_reset_on_return"] = pool_reset_on_return
+ if not isinstance(pool_timeout, _DefaultPlaceholder):
+ current_kwargs["pool_timeout"] = pool_timeout
+ if not isinstance(pool_use_lifo, _DefaultPlaceholder):
+ current_kwargs["pool_use_lifo"] = pool_use_lifo
+ if not isinstance(plugins, _DefaultPlaceholder):
+ current_kwargs["plugins"] = plugins
+ if not isinstance(query_cache_size, _DefaultPlaceholder):
+ current_kwargs["query_cache_size"] = query_cache_size
+ current_kwargs.update(kwargs)
+ return _create_engine(url, **current_kwargs)
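+
+
+# A short usage sketch (the SQLite file name is illustrative):
+#
+#     engine = create_engine("sqlite:///database.db", echo=True)
+#
+# Here only `echo` (and `future`) are forwarded to SQLAlchemy; every parameter
+# still wrapped in a _DefaultPlaceholder keeps SQLAlchemy's own default.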
--- /dev/null
+from typing import Generic, Iterator, List, Optional, TypeVar
+
+from sqlalchemy.engine.result import Result as _Result
+from sqlalchemy.engine.result import ScalarResult as _ScalarResult
+
+_T = TypeVar("_T")
+
+
+class ScalarResult(_ScalarResult, Generic[_T]):
+ def all(self) -> List[_T]:
+ return super().all()
+
+ def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
+ return super().partitions(size)
+
+ def fetchall(self) -> List[_T]:
+ return super().fetchall()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[_T]:
+ return super().fetchmany(size)
+
+ def __iter__(self) -> Iterator[_T]:
+ return super().__iter__()
+
+ def __next__(self) -> _T:
+ return super().__next__()
+
+ def first(self) -> Optional[_T]:
+ return super().first()
+
+ def one_or_none(self) -> Optional[_T]:
+ return super().one_or_none()
+
+ def one(self) -> _T:
+ return super().one()
+
+
+class Result(_Result, Generic[_T]):
+ def scalars(self, index: int = 0) -> ScalarResult[_T]:
+ return super().scalars(index) # type: ignore
+
+ def __iter__(self) -> Iterator[_T]: # type: ignore
+ return super().__iter__() # type: ignore
+
+ def __next__(self) -> _T: # type: ignore
+ return super().__next__() # type: ignore
+
+ def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore
+ return super().partitions(size) # type: ignore
+
+ def fetchall(self) -> List[_T]: # type: ignore
+ return super().fetchall() # type: ignore
+
+ def fetchone(self) -> Optional[_T]: # type: ignore
+ return super().fetchone() # type: ignore
+
+ def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore
+        return super().fetchmany(size)  # type: ignore
+
+ def all(self) -> List[_T]: # type: ignore
+ return super().all() # type: ignore
+
+ def first(self) -> Optional[_T]: # type: ignore
+ return super().first() # type: ignore
+
+ def one_or_none(self) -> Optional[_T]: # type: ignore
+ return super().one_or_none() # type: ignore
+
+ def scalar_one(self) -> _T:
+ return super().scalar_one() # type: ignore
+
+ def scalar_one_or_none(self) -> Optional[_T]:
+ return super().scalar_one_or_none() # type: ignore
+
+ def one(self) -> _T: # type: ignore
+ return super().one() # type: ignore
+
+ def scalar(self) -> Optional[_T]:
+ return super().scalar() # type: ignore
--- /dev/null
+from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
+
+from sqlalchemy import util
+from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
+from sqlalchemy.ext.asyncio import engine
+from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
+from sqlalchemy.util.concurrency import greenlet_spawn
+from sqlmodel.sql.base import Executable
+
+from ...engine.result import ScalarResult
+from ...orm.session import Session
+from ...sql.expression import Select
+
+_T = TypeVar("_T")
+
+
+class AsyncSession(_AsyncSession):
+ sync_session: Session
+
+ def __init__(
+ self,
+ bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
+ binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
+        **kw: Any,
+ ):
+        # The same code as the original AsyncSession, but creating SQLModel's Session
+ kw["future"] = True
+ if bind:
+ self.bind = bind
+ bind = engine._get_sync_engine_or_connection(bind) # type: ignore
+
+ if binds:
+ self.binds = binds
+ binds = {
+ key: engine._get_sync_engine_or_connection(b) # type: ignore
+ for key, b in binds.items()
+ }
+
+ self.sync_session = self._proxied = self._assign_proxied( # type: ignore
+ Session(bind=bind, binds=binds, **kw) # type: ignore
+ )
+
+ async def exec(
+ self,
+ statement: Union[Select[_T], Executable[_T]],
+ params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+ execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
+ bind_arguments: Optional[Mapping[str, Any]] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ # TODO: the documentation says execution_options accepts a dict, but only
+ # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
+ execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
+
+ return await greenlet_spawn( # type: ignore
+ self.sync_session.exec,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw,
+ )
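+
+
+# A short usage sketch (the engine and the `Hero` table model are assumed to
+# be defined elsewhere; both names are illustrative):
+#
+#     async with AsyncSession(async_engine) as session:
+#         heroes = (await session.exec(select(Hero))).all()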
--- /dev/null
+import ipaddress
+import uuid
+import weakref
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from enum import Enum
+from pathlib import Path
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ ForwardRef,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from pydantic import BaseModel
+from pydantic.errors import ConfigError, DictError
+from pydantic.fields import FieldInfo as PydanticFieldInfo
+from pydantic.fields import ModelField, Undefined, UndefinedType
+from pydantic.main import BaseConfig, ModelMetaclass, validate_model
+from pydantic.typing import NoArgAnyCallable, resolve_annotations
+from pydantic.utils import ROOT_KEY, Representation, ValueItems
+from sqlalchemy import (
+    Boolean,
+    Column,
+    Date,
+    DateTime,
+    Enum as sa_Enum,
+    Float,
+    ForeignKey,
+    Integer,
+    Interval,
+    Numeric,
+    inspect,
+)
+from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
+from sqlalchemy.orm.attributes import set_attribute
+from sqlalchemy.orm.decl_api import DeclarativeMeta
+from sqlalchemy.orm.instrumentation import is_instrumented
+from sqlalchemy.sql.schema import MetaData
+from sqlalchemy.sql.sqltypes import LargeBinary, Time
+
+from .sql.sqltypes import GUID, AutoString
+
+_T = TypeVar("_T")
+
+
+def __dataclass_transform__(
+ *,
+ eq_default: bool = True,
+ order_default: bool = False,
+ kw_only_default: bool = False,
+    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (),
+) -> Callable[[_T], _T]:
+ return lambda a: a
+
+
+class FieldInfo(PydanticFieldInfo):
+ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
+ primary_key = kwargs.pop("primary_key", False)
+ nullable = kwargs.pop("nullable", Undefined)
+ foreign_key = kwargs.pop("foreign_key", Undefined)
+ index = kwargs.pop("index", Undefined)
+ sa_column = kwargs.pop("sa_column", Undefined)
+ sa_column_args = kwargs.pop("sa_column_args", Undefined)
+ sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
+ if sa_column is not Undefined:
+ if sa_column_args is not Undefined:
+ raise RuntimeError(
+ "Passing sa_column_args is not supported when "
+ "also passing a sa_column"
+ )
+ if sa_column_kwargs is not Undefined:
+ raise RuntimeError(
+ "Passing sa_column_kwargs is not supported when "
+ "also passing a sa_column"
+ )
+ super().__init__(default=default, **kwargs)
+ self.primary_key = primary_key
+ self.nullable = nullable
+ self.foreign_key = foreign_key
+ self.index = index
+ self.sa_column = sa_column
+ self.sa_column_args = sa_column_args
+ self.sa_column_kwargs = sa_column_kwargs
+
+
+class RelationshipInfo(Representation):
+ def __init__(
+ self,
+ *,
+ back_populates: Optional[str] = None,
+ link_model: Optional[Any] = None,
+ sa_relationship: Optional[RelationshipProperty] = None,
+ sa_relationship_args: Optional[Sequence[Any]] = None,
+ sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
+ ) -> None:
+ if sa_relationship is not None:
+ if sa_relationship_args is not None:
+ raise RuntimeError(
+ "Passing sa_relationship_args is not supported when "
+ "also passing a sa_relationship"
+ )
+ if sa_relationship_kwargs is not None:
+ raise RuntimeError(
+ "Passing sa_relationship_kwargs is not supported when "
+ "also passing a sa_relationship"
+ )
+ self.back_populates = back_populates
+ self.link_model = link_model
+ self.sa_relationship = sa_relationship
+ self.sa_relationship_args = sa_relationship_args
+ self.sa_relationship_kwargs = sa_relationship_kwargs
+
+
+def Field(
+ default: Any = Undefined,
+ *,
+ default_factory: Optional[NoArgAnyCallable] = None,
+    alias: Optional[str] = None,
+    title: Optional[str] = None,
+    description: Optional[str] = None,
+    exclude: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    include: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    const: Optional[bool] = None,
+    gt: Optional[float] = None,
+    ge: Optional[float] = None,
+    lt: Optional[float] = None,
+    le: Optional[float] = None,
+    multiple_of: Optional[float] = None,
+    min_items: Optional[int] = None,
+    max_items: Optional[int] = None,
+    min_length: Optional[int] = None,
+    max_length: Optional[int] = None,
+    allow_mutation: bool = True,
+    regex: Optional[str] = None,
+ primary_key: bool = False,
+ foreign_key: Optional[Any] = None,
+ nullable: Union[bool, UndefinedType] = Undefined,
+ index: Union[bool, UndefinedType] = Undefined,
+ sa_column: Union[Column, UndefinedType] = Undefined,
+ sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
+ sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+ schema_extra: Optional[Dict[str, Any]] = None,
+) -> Any:
+ current_schema_extra = schema_extra or {}
+ field_info = FieldInfo(
+ default,
+ default_factory=default_factory,
+ alias=alias,
+ title=title,
+ description=description,
+ exclude=exclude,
+ include=include,
+ const=const,
+ gt=gt,
+ ge=ge,
+ lt=lt,
+ le=le,
+ multiple_of=multiple_of,
+ min_items=min_items,
+ max_items=max_items,
+ min_length=min_length,
+ max_length=max_length,
+ allow_mutation=allow_mutation,
+ regex=regex,
+ primary_key=primary_key,
+ foreign_key=foreign_key,
+ nullable=nullable,
+ index=index,
+ sa_column=sa_column,
+ sa_column_args=sa_column_args,
+ sa_column_kwargs=sa_column_kwargs,
+ **current_schema_extra,
+ )
+ field_info._validate()
+ return field_info
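+
+
+# A short usage sketch (the `Hero` model is illustrative):
+#
+#     class Hero(SQLModel, table=True):
+#         id: Optional[int] = Field(default=None, primary_key=True)
+#         name: str = Field(index=True)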
+
+
+def Relationship(
+ *,
+ back_populates: Optional[str] = None,
+ link_model: Optional[Any] = None,
+ sa_relationship: Optional[RelationshipProperty] = None,
+ sa_relationship_args: Optional[Sequence[Any]] = None,
+ sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
+) -> Any:
+ relationship_info = RelationshipInfo(
+ back_populates=back_populates,
+ link_model=link_model,
+ sa_relationship=sa_relationship,
+ sa_relationship_args=sa_relationship_args,
+ sa_relationship_kwargs=sa_relationship_kwargs,
+ )
+ return relationship_info
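+
+
+# A short usage sketch (the models are illustrative):
+#
+#     class Team(SQLModel, table=True):
+#         id: Optional[int] = Field(default=None, primary_key=True)
+#         heroes: List["Hero"] = Relationship(back_populates="team")
+#
+#     class Hero(SQLModel, table=True):
+#         id: Optional[int] = Field(default=None, primary_key=True)
+#         team_id: Optional[int] = Field(default=None, foreign_key="team.id")
+#         team: Optional[Team] = Relationship(back_populates="heroes")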
+
+
+@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
+class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
+ __sqlmodel_relationships__: Dict[str, RelationshipInfo]
+ __config__: Type[BaseConfig]
+ __fields__: Dict[str, ModelField]
+
+ # Replicate SQLAlchemy
+ def __setattr__(cls, name: str, value: Any) -> None:
+ if getattr(cls.__config__, "table", False): # type: ignore
+ DeclarativeMeta.__setattr__(cls, name, value)
+ else:
+ super().__setattr__(name, value)
+
+ def __delattr__(cls, name: str) -> None:
+ if getattr(cls.__config__, "table", False): # type: ignore
+ DeclarativeMeta.__delattr__(cls, name)
+ else:
+ super().__delattr__(name)
+
+ # From Pydantic
+    def __new__(
+        cls,
+        name: str,
+        bases: Tuple[type, ...],
+        class_dict: Dict[str, Any],
+        **kwargs: Any,
+    ) -> Any:
+ relationships: Dict[str, RelationshipInfo] = {}
+ dict_for_pydantic = {}
+ original_annotations = resolve_annotations(
+ class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
+ )
+ pydantic_annotations = {}
+ relationship_annotations = {}
+ for k, v in class_dict.items():
+ if isinstance(v, RelationshipInfo):
+ relationships[k] = v
+ else:
+ dict_for_pydantic[k] = v
+ for k, v in original_annotations.items():
+ if k in relationships:
+ relationship_annotations[k] = v
+ else:
+ pydantic_annotations[k] = v
+ dict_used = {
+ **dict_for_pydantic,
+ "__weakref__": None,
+ "__sqlmodel_relationships__": relationships,
+ "__annotations__": pydantic_annotations,
+ }
+        # Duplicate Pydantic's logic to filter config kwargs: if they were passed
+        # directly (including the registry), Pydantic would pass them on to the
+        # superclass, causing an error
+ allowed_config_kwargs: Set[str] = {
+ key
+ for key in dir(BaseConfig)
+ if not (
+ key.startswith("__") and key.endswith("__")
+ ) # skip dunder methods and attributes
+ }
+ pydantic_kwargs = kwargs.copy()
+ config_kwargs = {
+ key: pydantic_kwargs.pop(key)
+ for key in pydantic_kwargs.keys() & allowed_config_kwargs
+ }
+ new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
+ new_cls.__annotations__ = {
+ **relationship_annotations,
+ **pydantic_annotations,
+ **new_cls.__annotations__,
+ }
+
+ def get_config(name: str) -> Any:
+ config_class_value = getattr(new_cls.__config__, name, Undefined)
+ if config_class_value is not Undefined:
+ return config_class_value
+ kwarg_value = kwargs.get(name, Undefined)
+ if kwarg_value is not Undefined:
+ return kwarg_value
+ return Undefined
+
+ config_table = get_config("table")
+ if config_table is True:
+ # If it was passed by kwargs, ensure it's also set in config
+ new_cls.__config__.table = config_table
+ for k, v in new_cls.__fields__.items():
+ col = get_column_from_field(v)
+ setattr(new_cls, k, col)
+ # Set a config flag to tell FastAPI that this should be read with a field
+ # in orm_mode instead of preemptively converting it to a dict.
+ # This could be done by reading new_cls.__config__.table in FastAPI, but
+            # that's very specific to SQLModel, so use another config flag that
+            # other future tools based on Pydantic can also use.
+ new_cls.__config__.read_with_orm_mode = True
+
+ config_registry = get_config("registry")
+ if config_registry is not Undefined:
+ config_registry = cast(registry, config_registry)
+ # If it was passed by kwargs, ensure it's also set in config
+            new_cls.__config__.registry = config_registry
+ setattr(new_cls, "_sa_registry", config_registry)
+ setattr(new_cls, "metadata", config_registry.metadata)
+ setattr(new_cls, "__abstract__", True)
+ return new_cls
+
+ # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
+ def __init__(
+ cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
+ ) -> None:
+        # Only one of the base classes (or the current one) should be a table model;
+        # this allows FastAPI to clone a SQLModel for the response_model without
+        # trying to create a new SQLAlchemy model, for a new table, with the same
+        # name, which would trigger an error
+ base_is_table = False
+ for base in bases:
+ config = getattr(base, "__config__")
+ if config and getattr(config, "table", False):
+ base_is_table = True
+ break
+ if getattr(cls.__config__, "table", False) and not base_is_table:
+ dict_used = dict_.copy()
+ for field_name, field_value in cls.__fields__.items():
+ dict_used[field_name] = get_column_from_field(field_value)
+ for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
+ if rel_info.sa_relationship:
+                    # There's a SQLAlchemy relationship declared; it takes precedence
+                    # over anything else, so use it and continue with the next attribute
+ dict_used[rel_name] = rel_info.sa_relationship
+ continue
+ ann = cls.__annotations__[rel_name]
+ temp_field = ModelField.infer(
+ name=rel_name,
+ value=rel_info,
+ annotation=ann,
+ class_validators=None,
+ config=BaseConfig,
+ )
+ relationship_to = temp_field.type_
+ if isinstance(temp_field.type_, ForwardRef):
+ relationship_to = temp_field.type_.__forward_arg__
+ rel_kwargs: Dict[str, Any] = {}
+ if rel_info.back_populates:
+ rel_kwargs["back_populates"] = rel_info.back_populates
+ if rel_info.link_model:
+ ins = inspect(rel_info.link_model)
+ local_table = getattr(ins, "local_table")
+ if local_table is None:
+ raise RuntimeError(
+ "Couldn't find the secondary table for "
+ f"model {rel_info.link_model}"
+ )
+ rel_kwargs["secondary"] = local_table
+ rel_args: List[Any] = []
+ if rel_info.sa_relationship_args:
+ rel_args.extend(rel_info.sa_relationship_args)
+ if rel_info.sa_relationship_kwargs:
+ rel_kwargs.update(rel_info.sa_relationship_kwargs)
+ rel_value: RelationshipProperty = relationship(
+ relationship_to, *rel_args, **rel_kwargs
+ )
+ dict_used[rel_name] = rel_value
+ DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
+ else:
+ ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
+
+
+def get_sqlalchemy_type(field: ModelField) -> Any:
+ if issubclass(field.type_, str):
+ if field.field_info.max_length:
+ return AutoString(length=field.field_info.max_length)
+ return AutoString
+ if issubclass(field.type_, float):
+ return Float
+ if issubclass(field.type_, bool):
+ return Boolean
+ if issubclass(field.type_, int):
+ return Integer
+ if issubclass(field.type_, datetime):
+ return DateTime
+ if issubclass(field.type_, date):
+ return Date
+ if issubclass(field.type_, timedelta):
+ return Interval
+ if issubclass(field.type_, time):
+ return Time
+    if issubclass(field.type_, Enum):
+        return sa_Enum(field.type_)
+ if issubclass(field.type_, bytes):
+ return LargeBinary
+ if issubclass(field.type_, Decimal):
+ return Numeric
+ if issubclass(field.type_, ipaddress.IPv4Address):
+ return AutoString
+ if issubclass(field.type_, ipaddress.IPv4Network):
+ return AutoString
+ if issubclass(field.type_, ipaddress.IPv6Address):
+ return AutoString
+ if issubclass(field.type_, ipaddress.IPv6Network):
+ return AutoString
+ if issubclass(field.type_, Path):
+ return AutoString
+    if issubclass(field.type_, uuid.UUID):
+        return GUID
+    raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
+
+
+def get_column_from_field(field: ModelField) -> Column:
+ sa_column = getattr(field.field_info, "sa_column", Undefined)
+ if isinstance(sa_column, Column):
+ return sa_column
+    sa_type = get_sqlalchemy_type(field)
+ primary_key = getattr(field.field_info, "primary_key", False)
+ nullable = not field.required
+ index = getattr(field.field_info, "index", Undefined)
+ if index is Undefined:
+ index = True
+ if hasattr(field.field_info, "nullable"):
+ field_nullable = getattr(field.field_info, "nullable")
+ if field_nullable != Undefined:
+ nullable = field_nullable
+ args = []
+ foreign_key = getattr(field.field_info, "foreign_key", None)
+ if foreign_key:
+ args.append(ForeignKey(foreign_key))
+ kwargs = {
+ "primary_key": primary_key,
+ "nullable": nullable,
+ "index": index,
+ }
+ sa_default = Undefined
+ if field.field_info.default_factory:
+ sa_default = field.field_info.default_factory
+ elif field.field_info.default is not Undefined:
+ sa_default = field.field_info.default
+ if sa_default is not Undefined:
+ kwargs["default"] = sa_default
+ sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
+ if sa_column_args is not Undefined:
+ args.extend(list(cast(Sequence, sa_column_args)))
+ sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
+ if sa_column_kwargs is not Undefined:
+ kwargs.update(cast(dict, sa_column_kwargs))
+ return Column(sa_type, *args, **kwargs)
+
+
+class_registry = weakref.WeakValueDictionary() # type: ignore
+
+default_registry = registry()
+
+
+class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
+ # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
+ __slots__ = ("__weakref__",)
+ __tablename__: ClassVar[Union[str, Callable[..., str]]]
+ __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
+ __name__: ClassVar[str]
+ metadata: ClassVar[MetaData]
+
+ class Config:
+ orm_mode = True
+
+ def __new__(cls, *args, **kwargs) -> Any:
+ new_object = super().__new__(cls)
+ # SQLAlchemy doesn't call __init__ on the base class
+ # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
+ # Set __fields_set__ here, that would have been set when calling __init__
+ # in the Pydantic model so that when SQLAlchemy sets attributes that are
+ # added (e.g. when querying from DB) to the __fields_set__, this already exists
+ object.__setattr__(new_object, "__fields_set__", set())
+ return new_object
+
+ def __init__(__pydantic_self__, **data: Any) -> None:
+        # Uses something other than `self` for the first arg to allow "self" as a
+        # settable attribute
+ if TYPE_CHECKING:
+ __pydantic_self__.__dict__: Dict[str, Any] = {}
+ __pydantic_self__.__fields_set__: Set[str] = set()
+ values, fields_set, validation_error = validate_model(
+ __pydantic_self__.__class__, data
+ )
+        # Only raise validation errors if this is not a table model
+ if (
+ not getattr(__pydantic_self__.__config__, "table", False)
+ and validation_error
+ ):
+ raise validation_error
+ # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
+ # can handle them
+ # object.__setattr__(__pydantic_self__, '__dict__', values)
+ object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
+ for key, value in values.items():
+ setattr(__pydantic_self__, key, value)
+ non_pydantic_keys = data.keys() - values.keys()
+ for key in non_pydantic_keys:
+ if key in __pydantic_self__.__sqlmodel_relationships__:
+ setattr(__pydantic_self__, key, data[key])
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ if name in {"_sa_instance_state"}:
+ self.__dict__[name] = value
+ return
+ else:
+ # Set in SQLAlchemy, before Pydantic to trigger events and updates
+ if getattr(self.__config__, "table", False):
+ if is_instrumented(self, name):
+ set_attribute(self, name, value)
+ # Set in Pydantic model to trigger possible validation changes, only for
+ # non relationship values
+ if name not in self.__sqlmodel_relationships__:
+ super().__setattr__(name, value)
+
+ @classmethod
+    def from_orm(
+        cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
+    ) -> "SQLModel":
+ # Duplicated from Pydantic
+ if not cls.__config__.orm_mode:
+ raise ConfigError(
+ "You must have the config attribute orm_mode=True to use from_orm"
+ )
+ obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
+ # SQLModel, support update dict
+ if update is not None:
+ obj = {**obj, **update}
+ # End SQLModel support dict
+ if not getattr(cls.__config__, "table", False):
+ # If not table, normal Pydantic code
+ m = cls.__new__(cls)
+ else:
+ # If table, create the new instance normally to make SQLAlchemy create
+ # the _sa_instance_state attribute
+ m = cls()
+ values, fields_set, validation_error = validate_model(cls, obj)
+ if validation_error:
+ raise validation_error
+ # Updated to trigger SQLAlchemy internal handling
+ if not getattr(cls.__config__, "table", False):
+ object.__setattr__(m, "__dict__", values)
+ else:
+ for key, value in values.items():
+ setattr(m, key, value)
+ # Continue with standard Pydantic logic
+ object.__setattr__(m, "__fields_set__", fields_set)
+ m._init_private_attributes()
+ return m
+
+ @classmethod
+ def parse_obj(
+        cls: Type["SQLModel"], obj: Any, update: Optional[Dict[str, Any]] = None
+ ) -> "SQLModel":
+ obj = cls._enforce_dict_if_root(obj)
+ # SQLModel, support update dict
+ if update is not None:
+ obj = {**obj, **update}
+ # End SQLModel support dict
+ return super().parse_obj(obj)
+
+ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
+ # Don't show SQLAlchemy private attributes
+ return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
+
+ # From Pydantic, override to enforce validation with dict
+ @classmethod
+ def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel":
+ if isinstance(value, cls):
+ return value.copy() if cls.__config__.copy_on_model_validation else value
+
+ value = cls._enforce_dict_if_root(value)
+ if isinstance(value, dict):
+ values, fields_set, validation_error = validate_model(cls, value)
+ if validation_error:
+ raise validation_error
+ model = cls(**values)
+            # Reset the fields set; in Pydantic this would have been done in __init__
+ object.__setattr__(model, "__fields_set__", fields_set)
+ return model
+ elif cls.__config__.orm_mode:
+ return cls.from_orm(value)
+ elif cls.__custom_root_type__:
+ return cls.parse_obj(value)
+ else:
+ try:
+ value_as_dict = dict(value)
+ except (TypeError, ValueError) as e:
+ raise DictError() from e
+ return cls(**value_as_dict)
+
+ # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
+ def _calculate_keys(
+ self,
+ include: Optional[Mapping[Union[int, str], Any]],
+ exclude: Optional[Mapping[Union[int, str], Any]],
+ exclude_unset: bool,
+ update: Optional[Dict[str, Any]] = None,
+ ) -> Optional[AbstractSet[str]]:
+ if include is None and exclude is None and exclude_unset is False:
+ # Original in Pydantic:
+ # return None
+ # Updated to not return SQLAlchemy attributes
+ # Do not include relationships as that would easily lead to infinite
+ # recursion, or traversing the whole database
+ return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
+
+ keys: AbstractSet[str]
+ if exclude_unset:
+ keys = self.__fields_set__.copy()
+ else:
+ # Original in Pydantic:
+ # keys = self.__dict__.keys()
+ # Updated to not return SQLAlchemy attributes
+ # Do not include relationships as that would easily lead to infinite
+ # recursion, or traversing the whole database
+ keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()
+
+ if include is not None:
+ keys &= include.keys()
+
+ if update:
+ keys -= update.keys()
+
+ if exclude:
+ keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)}
+
+ return keys
+
+ @declared_attr # type: ignore
+ def __tablename__(cls) -> str:
+ return cls.__name__.lower()
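+
+
+# By default a table model's table name is the lowercased class name, so a
+# `Hero` model maps to a table named "hero". A sketch of overriding it (the
+# model is illustrative):
+#
+#     class Hero(SQLModel, table=True):
+#         __tablename__ = "heroes"
+#         id: Optional[int] = Field(default=None, primary_key=True)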
--- /dev/null
+from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload
+
+from sqlalchemy import util
+from sqlalchemy.orm import Query as _Query
+from sqlalchemy.orm import Session as _Session
+from sqlalchemy.sql.base import Executable as _Executable
+from sqlmodel.sql.expression import Select, SelectOfScalar
+from typing_extensions import Literal
+
+from ..engine.result import Result, ScalarResult
+from ..sql.base import Executable
+
+_T = TypeVar("_T")
+
+
+class Session(_Session):
+ @overload
+ def exec(
+ self,
+ statement: Select[_T],
+ *,
+ params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+ execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+ bind_arguments: Optional[Mapping[str, Any]] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ **kw: Any,
+    ) -> Result[_T]:
+ ...
+
+ @overload
+ def exec(
+ self,
+ statement: SelectOfScalar[_T],
+ *,
+ params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+ execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+ bind_arguments: Optional[Mapping[str, Any]] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ **kw: Any,
+    ) -> ScalarResult[_T]:
+ ...
+
+ def exec(
+ self,
+ statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
+ *,
+ params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+ execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+ bind_arguments: Optional[Mapping[str, Any]] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ **kw: Any,
+ ) -> Union[Result[_T], ScalarResult[_T]]:
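+        """
+        Execute a SELECT statement and return typed, model-aware results.
+
+        A short example (a sketch, assuming a `Hero` table model):
+
+        ```Python
+        heroes = session.exec(select(Hero)).all()
+        ```
+
+        A single-entity `select()` produces scalar results (model instances),
+        while selecting multiple entities produces rows of tuples.
+        """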
+ results = super().execute(
+ statement,
+ params=params,
+ execution_options=execution_options, # type: ignore
+ bind_arguments=bind_arguments,
+ _parent_execute_state=_parent_execute_state,
+ _add_event=_add_event,
+ **kw,
+ )
+ if isinstance(statement, SelectOfScalar):
+ return results.scalars() # type: ignore
+ return results # type: ignore
+
+ def execute(
+ self,
+ statement: _Executable,
+ params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+ execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+ bind_arguments: Optional[Mapping[str, Any]] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ **kw: Any,
+ ) -> Result[Any]:
+ """
+ 🚨 You probably want to use `session.exec()` instead of `session.execute()`.
+
+        This is the original SQLAlchemy `session.execute()` method that returns objects
+        of type `Row`, where you have to call `scalars()` to get the model objects.
+
+ For example:
+
+ ```Python
+ heroes = session.execute(select(Hero)).scalars().all()
+ ```
+
+        Instead, you could use `exec()`:
+
+ ```Python
+ heroes = session.exec(select(Hero)).all()
+ ```
+ """
+ return super().execute( # type: ignore
+ statement,
+ params=params,
+ execution_options=execution_options, # type: ignore
+ bind_arguments=bind_arguments,
+ _parent_execute_state=_parent_execute_state,
+ _add_event=_add_event,
+ **kw,
+ )
+
+ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
+ """
+ 🚨 You probably want to use `session.exec()` instead of `session.query()`.
+
+        `session.exec()` is SQLModel's own, shorter version with improved type
+        annotations.
+
+ Or otherwise you might want to use `session.execute()` instead of
+ `session.query()`.
+ """
+ return super().query(*entities, **kwargs)
+
+ def get(
+ self,
+        entity: Type[_T],
+ ident: Any,
+ options: Optional[Sequence[Any]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
+ identity_token: Optional[Any] = None,
+    ) -> Optional[_T]:
+ return super().get(
+ entity,
+ ident,
+ options=options,
+ populate_existing=populate_existing,
+ with_for_update=with_for_update,
+ identity_token=identity_token,
+ )
--- /dev/null
+from sqlalchemy.pool import StaticPool as StaticPool # noqa: F401
--- /dev/null
+from typing import Generic, TypeVar
+
+from sqlalchemy.sql.base import Executable as _Executable
+
+_T = TypeVar("_T")
+
+
+class Executable(_Executable, Generic[_T]):
+    # Typing-only subclass: the Generic parameter carries the row/model type of
+    # the statement, while the runtime behavior is inherited unchanged.
+    pass
--- /dev/null
+# WARNING: do not modify this code, it is generated by expression.py.jinja2
+
+import sys
+from datetime import datetime
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Generic,
+ Mapping,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+from uuid import UUID
+
+from sqlalchemy import Column
+from sqlalchemy.orm import InstrumentedAttribute
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.expression import Select as _Select
+
+_TSelect = TypeVar("_TSelect")
+
+# Workaround Generics incompatibility in Python 3.6
+# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
+if sys.version_info.minor >= 7:
+
+ class Select(_Select, Generic[_TSelect]):
+ pass
+
+    # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, which has a
+    # different purpose. This is the same as a normal SQLAlchemy Select class where
+    # there's only one entity, so the result will be converted to a scalar by default.
+    # This way, writing for loops on the results will feel natural.
+ class SelectOfScalar(_Select, Generic[_TSelect]):
+ pass
+
+
+else:
+ from typing import GenericMeta # type: ignore
+
+ class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
+ pass
+
+ class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+ pass
+
+ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+ pass
+
+    # Cast them so editors work correctly; of several tricks tried, this one works
+    # in both VS Code and PyCharm
+ Select = cast("Select", _Py36Select) # type: ignore
+ SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
+
+
+if TYPE_CHECKING: # pragma: no cover
+ from ..main import SQLModel
+
+# Generated TypeVars start
+
+
+_TScalar_0 = TypeVar(
+ "_TScalar_0",
+ Column,
+ Sequence,
+ Mapping,
+ UUID,
+ datetime,
+ float,
+ int,
+ bool,
+ bytes,
+ str,
+ None,
+)
+
+_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
+
+
+_TScalar_1 = TypeVar(
+ "_TScalar_1",
+ Column,
+ Sequence,
+ Mapping,
+ UUID,
+ datetime,
+ float,
+ int,
+ bool,
+ bytes,
+ str,
+ None,
+)
+
+_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
+
+
+_TScalar_2 = TypeVar(
+ "_TScalar_2",
+ Column,
+ Sequence,
+ Mapping,
+ UUID,
+ datetime,
+ float,
+ int,
+ bool,
+ bytes,
+ str,
+ None,
+)
+
+_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
+
+
+_TScalar_3 = TypeVar(
+ "_TScalar_3",
+ Column,
+ Sequence,
+ Mapping,
+ UUID,
+ datetime,
+ float,
+ int,
+ bool,
+ bytes,
+ str,
+ None,
+)
+
+_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
+
+
+# Generated TypeVars end
+
+
+@overload
+def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
+ ...
+
+
+@overload
+def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
+ ...
+
+
+# Generated overloads start
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ entity_2: _TScalar_2,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ entity_2: Type[_TModel_2],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ entity_2: _TScalar_2,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ entity_2: Type[_TModel_2],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ entity_2: _TScalar_2,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ entity_2: Type[_TModel_2],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ entity_2: _TScalar_2,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ entity_2: Type[_TModel_2],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ entity_2: _TScalar_2,
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ entity_2: _TScalar_2,
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ entity_2: Type[_TModel_2],
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: _TScalar_1,
+ entity_2: Type[_TModel_2],
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ entity_2: _TScalar_2,
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ entity_2: _TScalar_2,
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ entity_2: Type[_TModel_2],
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: _TScalar_0,
+ entity_1: Type[_TModel_1],
+ entity_2: Type[_TModel_2],
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ entity_2: _TScalar_2,
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ entity_2: _TScalar_2,
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ entity_2: Type[_TModel_2],
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: _TScalar_1,
+ entity_2: Type[_TModel_2],
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ entity_2: _TScalar_2,
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ entity_2: _TScalar_2,
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ entity_2: Type[_TModel_2],
+ entity_3: _TScalar_3,
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
+ ...
+
+
+@overload
+def select( # type: ignore
+ entity_0: Type[_TModel_0],
+ entity_1: Type[_TModel_1],
+ entity_2: Type[_TModel_2],
+ entity_3: Type[_TModel_3],
+ **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
+ ...
+
+
+# Generated overloads end
+
+
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
+ if len(entities) == 1:
+ return SelectOfScalar._create(*entities, **kw) # type: ignore
+ return Select._create(*entities, **kw) # type: ignore
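+
+
+# A short usage sketch (the `Hero` model is illustrative):
+#
+#     statement = select(Hero)                 # a SelectOfScalar[Hero]
+#     statement = select(Hero.id, Hero.name)   # a Select of tuples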
+
+
+# TODO: add several @overload from Python types to SQLAlchemy equivalents
+def col(column_expression: Any) -> ColumnClause:
+ if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
+ raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
+ return column_expression
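+
+
+# A short usage sketch (the `Hero` model and its Optional[str] field `name`
+# are illustrative):
+#
+#     statement = select(Hero).where(col(Hero.name) == "Deadpond")
+#
+# Wrapping the attribute in col() tells the type checker that it really is a
+# SQLAlchemy column, so comparisons on Optional fields aren't flagged.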
--- /dev/null
+import sys
+from datetime import datetime
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Generic,
+ Mapping,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+from uuid import UUID
+
+from sqlalchemy import Column
+from sqlalchemy.orm import InstrumentedAttribute
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.expression import Select as _Select
+
+_TSelect = TypeVar("_TSelect")
+
+# Workaround Generics incompatibility in Python 3.6
+# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
+if sys.version_info.minor >= 7:
+
+ class Select(_Select, Generic[_TSelect]):
+ pass
+
+    # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, which has a
+    # different purpose. This is the same as a normal SQLAlchemy Select class where
+    # there's only one entity, so the result will be converted to a scalar by default.
+    # This way, writing for loops on the results will feel natural.
+ class SelectOfScalar(_Select, Generic[_TSelect]):
+ pass
+
+
+else:
+ from typing import GenericMeta # type: ignore
+
+ class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
+ pass
+
+ class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+ pass
+
+ class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+ pass
+
+    # Cast them so editors work correctly; of several tricks tried, this one works
+    # in both VS Code and PyCharm
+ Select = cast("Select", _Py36Select) # type: ignore
+ SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore
+
+
+if TYPE_CHECKING: # pragma: no cover
+ from ..main import SQLModel
+
+# Generated TypeVars start
+
+{% for i in range(number_of_types) %}
+_TScalar_{{ i }} = TypeVar(
+ "_TScalar_{{ i }}",
+ Column,
+ Sequence,
+ Mapping,
+ UUID,
+ datetime,
+ float,
+ int,
+ bool,
+ bytes,
+ str,
+ None,
+)
+
+_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
+
+{% endfor %}
+
+# Generated TypeVars end
+
+@overload
+def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore
+ ...
+
+
+@overload
+def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore
+ ...
+
+
+# Generated overloads start
+
+{% for signature in signatures %}
+
+@overload
+def select( # type: ignore
+ {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
+) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
+ ...
+
+{% endfor %}
+
+# Generated overloads end
+
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
+ if len(entities) == 1:
+ return SelectOfScalar._create(*entities, **kw) # type: ignore
+    return Select._create(*entities, **kw)  # type: ignore
+
+
+# TODO: add several @overload from Python types to SQLAlchemy equivalents
+def col(column_expression: Any) -> ColumnClause:
+ if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
+ raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
+ return column_expression
--- /dev/null
+import uuid
+from typing import Any, cast
+
+from sqlalchemy import types
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.types import CHAR, TypeDecorator
+
+
+class AutoString(types.TypeDecorator):
+    impl = types.String
+ cache_ok = True
+ mysql_default_length = 255
+
+ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
+ impl = cast(types.String, self.impl)
+ if impl.length is None and dialect.name == "mysql":
+ return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore
+ return super().load_dialect_impl(dialect)
+
+
+# Reference from SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
+# with small modifications
+class GUID(TypeDecorator):
+ """Platform-independent GUID type.
+
+ Uses PostgreSQL's UUID type, otherwise uses
+ CHAR(32), storing as stringified hex values.
+
+ """
+
+ impl = CHAR
+ cache_ok = True
+
+ def load_dialect_impl(self, dialect):
+ if dialect.name == "postgresql":
+ return dialect.type_descriptor(UUID())
+ else:
+ return dialect.type_descriptor(CHAR(32))
+
+ def process_bind_param(self, value, dialect):
+ if value is None:
+ return value
+ elif dialect.name == "postgresql":
+ return str(value)
+ else:
+            if not isinstance(value, uuid.UUID):
+                return f"{uuid.UUID(value).int:032x}"
+            else:
+                # zero-pad to a 32-character hexstring to fill the CHAR(32)
+                return f"{value.int:032x}"
+
+ def process_result_value(self, value, dialect):
+ if value is None:
+ return value
+ else:
+ if not isinstance(value, uuid.UUID):
+ value = uuid.UUID(value)
+ return value
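+
+
+# A short usage sketch (an illustrative column declaration; Column comes from
+# sqlalchemy):
+#
+#     id_column = Column(GUID(), primary_key=True, default=uuid.uuid4)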