]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Add SQLModel core code
authorSebastián Ramírez <tiangolo@gmail.com>
Tue, 24 Aug 2021 12:41:53 +0000 (14:41 +0200)
committerSebastián Ramírez <tiangolo@gmail.com>
Tue, 24 Aug 2021 12:41:53 +0000 (14:41 +0200)
17 files changed:
sqlmodel/__init__.py [new file with mode: 0644]
sqlmodel/default.py [new file with mode: 0644]
sqlmodel/engine/__init__.py [new file with mode: 0644]
sqlmodel/engine/create.py [new file with mode: 0644]
sqlmodel/engine/result.py [new file with mode: 0644]
sqlmodel/ext/__init__.py [new file with mode: 0644]
sqlmodel/ext/asyncio/__init__.py [new file with mode: 0644]
sqlmodel/ext/asyncio/session.py [new file with mode: 0644]
sqlmodel/main.py [new file with mode: 0644]
sqlmodel/orm/__init__.py [new file with mode: 0644]
sqlmodel/orm/session.py [new file with mode: 0644]
sqlmodel/pool/__init__.py [new file with mode: 0644]
sqlmodel/sql/__init__.py [new file with mode: 0644]
sqlmodel/sql/base.py [new file with mode: 0644]
sqlmodel/sql/expression.py [new file with mode: 0644]
sqlmodel/sql/expression.py.jinja2 [new file with mode: 0644]
sqlmodel/sql/sqltypes.py [new file with mode: 0644]

diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py
new file mode 100644 (file)
index 0000000..cdfb889
--- /dev/null
@@ -0,0 +1,139 @@
+__version__ = "0.0.1"
+
+# Re-export from SQLAlchemy
+from sqlalchemy.engine import create_mock_engine as create_mock_engine
+from sqlalchemy.engine import engine_from_config as engine_from_config
+from sqlalchemy.inspection import inspect as inspect
+from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
+from sqlalchemy.schema import CheckConstraint as CheckConstraint
+from sqlalchemy.schema import Column as Column
+from sqlalchemy.schema import ColumnDefault as ColumnDefault
+from sqlalchemy.schema import Computed as Computed
+from sqlalchemy.schema import Constraint as Constraint
+from sqlalchemy.schema import DDL as DDL
+from sqlalchemy.schema import DefaultClause as DefaultClause
+from sqlalchemy.schema import FetchedValue as FetchedValue
+from sqlalchemy.schema import ForeignKey as ForeignKey
+from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint
+from sqlalchemy.schema import Identity as Identity
+from sqlalchemy.schema import Index as Index
+from sqlalchemy.schema import MetaData as MetaData
+from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
+from sqlalchemy.schema import Sequence as Sequence
+from sqlalchemy.schema import Table as Table
+from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
+from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
+from sqlalchemy.sql import alias as alias
+from sqlalchemy.sql import all_ as all_
+from sqlalchemy.sql import and_ as and_
+from sqlalchemy.sql import any_ as any_
+from sqlalchemy.sql import asc as asc
+from sqlalchemy.sql import between as between
+from sqlalchemy.sql import bindparam as bindparam
+from sqlalchemy.sql import case as case
+from sqlalchemy.sql import cast as cast
+from sqlalchemy.sql import collate as collate
+from sqlalchemy.sql import column as column
+from sqlalchemy.sql import delete as delete
+from sqlalchemy.sql import desc as desc
+from sqlalchemy.sql import distinct as distinct
+from sqlalchemy.sql import except_ as except_
+from sqlalchemy.sql import except_all as except_all
+from sqlalchemy.sql import exists as exists
+from sqlalchemy.sql import extract as extract
+from sqlalchemy.sql import false as false
+from sqlalchemy.sql import func as func
+from sqlalchemy.sql import funcfilter as funcfilter
+from sqlalchemy.sql import insert as insert
+from sqlalchemy.sql import intersect as intersect
+from sqlalchemy.sql import intersect_all as intersect_all
+from sqlalchemy.sql import join as join
+from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
+from sqlalchemy.sql import (
+    LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
+)
+from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
+from sqlalchemy.sql import (
+    LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
+)
+from sqlalchemy.sql import lambda_stmt as lambda_stmt
+from sqlalchemy.sql import lateral as lateral
+from sqlalchemy.sql import literal as literal
+from sqlalchemy.sql import literal_column as literal_column
+from sqlalchemy.sql import modifier as modifier
+from sqlalchemy.sql import not_ as not_
+from sqlalchemy.sql import null as null
+from sqlalchemy.sql import nulls_first as nulls_first
+from sqlalchemy.sql import nulls_last as nulls_last
+from sqlalchemy.sql import nullsfirst as nullsfirst
+from sqlalchemy.sql import nullslast as nullslast
+from sqlalchemy.sql import or_ as or_
+from sqlalchemy.sql import outerjoin as outerjoin
+from sqlalchemy.sql import outparam as outparam
+from sqlalchemy.sql import over as over
+from sqlalchemy.sql import subquery as subquery
+from sqlalchemy.sql import table as table
+from sqlalchemy.sql import tablesample as tablesample
+from sqlalchemy.sql import text as text
+from sqlalchemy.sql import true as true
+from sqlalchemy.sql import tuple_ as tuple_
+from sqlalchemy.sql import type_coerce as type_coerce
+from sqlalchemy.sql import union as union
+from sqlalchemy.sql import union_all as union_all
+from sqlalchemy.sql import update as update
+from sqlalchemy.sql import values as values
+from sqlalchemy.sql import within_group as within_group
+from sqlalchemy.types import ARRAY as ARRAY
+from sqlalchemy.types import BIGINT as BIGINT
+from sqlalchemy.types import BigInteger as BigInteger
+from sqlalchemy.types import BINARY as BINARY
+from sqlalchemy.types import BLOB as BLOB
+from sqlalchemy.types import BOOLEAN as BOOLEAN
+from sqlalchemy.types import Boolean as Boolean
+from sqlalchemy.types import CHAR as CHAR
+from sqlalchemy.types import CLOB as CLOB
+from sqlalchemy.types import DATE as DATE
+from sqlalchemy.types import Date as Date
+from sqlalchemy.types import DATETIME as DATETIME
+from sqlalchemy.types import DateTime as DateTime
+from sqlalchemy.types import DECIMAL as DECIMAL
+from sqlalchemy.types import Enum as Enum
+from sqlalchemy.types import FLOAT as FLOAT
+from sqlalchemy.types import Float as Float
+from sqlalchemy.types import INT as INT
+from sqlalchemy.types import INTEGER as INTEGER
+from sqlalchemy.types import Integer as Integer
+from sqlalchemy.types import Interval as Interval
+from sqlalchemy.types import JSON as JSON
+from sqlalchemy.types import LargeBinary as LargeBinary
+from sqlalchemy.types import NCHAR as NCHAR
+from sqlalchemy.types import NUMERIC as NUMERIC
+from sqlalchemy.types import Numeric as Numeric
+from sqlalchemy.types import NVARCHAR as NVARCHAR
+from sqlalchemy.types import PickleType as PickleType
+from sqlalchemy.types import REAL as REAL
+from sqlalchemy.types import SMALLINT as SMALLINT
+from sqlalchemy.types import SmallInteger as SmallInteger
+from sqlalchemy.types import String as String
+from sqlalchemy.types import TEXT as TEXT
+from sqlalchemy.types import Text as Text
+from sqlalchemy.types import TIME as TIME
+from sqlalchemy.types import Time as Time
+from sqlalchemy.types import TIMESTAMP as TIMESTAMP
+from sqlalchemy.types import TypeDecorator as TypeDecorator
+from sqlalchemy.types import Unicode as Unicode
+from sqlalchemy.types import UnicodeText as UnicodeText
+from sqlalchemy.types import VARBINARY as VARBINARY
+from sqlalchemy.types import VARCHAR as VARCHAR
+
+# Extensions and modifications of SQLAlchemy in SQLModel
+from .engine.create import create_engine as create_engine
+from .orm.session import Session as Session
+from .sql.expression import select as select
+from .sql.expression import col as col
+from .sql.sqltypes import AutoString as AutoString
+
+# Export SQLModel specifics (equivalent to Pydantic)
+from .main import SQLModel as SQLModel
+from .main import Field as Field
+from .main import Relationship as Relationship
diff --git a/sqlmodel/default.py b/sqlmodel/default.py
new file mode 100644 (file)
index 0000000..bb44972
--- /dev/null
@@ -0,0 +1,32 @@
+from typing import Any, TypeVar
+
+
+class _DefaultPlaceholder:
+    """
+    You shouldn't use this class directly.
+
+    It's used internally to recognize when a default value has been overwritten, even
+    if the overriden default value was truthy.
+    """
+
+    def __init__(self, value: Any):
+        self.value = value
+
+    def __bool__(self) -> bool:
+        return bool(self.value)
+
+    def __eq__(self, o: object) -> bool:
+        return isinstance(o, _DefaultPlaceholder) and o.value == self.value
+
+
+_TDefaultType = TypeVar("_TDefaultType")
+
+
+def Default(value: _TDefaultType) -> _TDefaultType:
+    """
+    You shouldn't use this function directly.
+
+    It's used internally to recognize when a default value has been overwritten, even
+    if the overriden default value was truthy.
+    """
+    return _DefaultPlaceholder(value)  # type: ignore
diff --git a/sqlmodel/engine/__init__.py b/sqlmodel/engine/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py
new file mode 100644 (file)
index 0000000..9748125
--- /dev/null
@@ -0,0 +1,139 @@
+import json
+import sqlite3
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+from sqlalchemy import create_engine as _create_engine
+from sqlalchemy.engine.url import URL
+from sqlalchemy.future import Engine as _FutureEngine
+from sqlalchemy.pool import Pool
+from typing_extensions import Literal, TypedDict
+
+from ..default import Default, _DefaultPlaceholder
+
+# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here
+
+_Debug = Literal["debug"]
+
+_IsolationLevel = Literal[
+    "SERIALIZABLE",
+    "REPEATABLE READ",
+    "READ COMMITTED",
+    "READ UNCOMMITTED",
+    "AUTOCOMMIT",
+]
+_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
+_ResetOnReturn = Literal["rollback", "commit"]
+
+
+class _SQLiteConnectArgs(TypedDict, total=False):
+    timeout: float
+    detect_types: Any
+    isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
+    check_same_thread: bool
+    factory: Type[sqlite3.Connection]
+    cached_statements: int
+    uri: bool
+
+
+_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
+
+
+# Re-define create_engine to have by default future=True, and assume that's what is used
+# Also show the default values used for each parameter, but don't set them unless
+# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't
+# support pool connection arguments.
+def create_engine(
+    url: Union[str, URL],
+    *,
+    connect_args: _ConnectArgs = Default({}),  # type: ignore
+    echo: Union[bool, _Debug] = Default(False),
+    echo_pool: Union[bool, _Debug] = Default(False),
+    enable_from_linting: bool = Default(True),
+    encoding: str = Default("utf-8"),
+    execution_options: Dict[Any, Any] = Default({}),
+    future: bool = True,
+    hide_parameters: bool = Default(False),
+    implicit_returning: bool = Default(True),
+    isolation_level: Optional[_IsolationLevel] = Default(None),
+    json_deserializer: Callable[..., Any] = Default(json.loads),
+    json_serializer: Callable[..., Any] = Default(json.dumps),
+    label_length: Optional[int] = Default(None),
+    logging_name: Optional[str] = Default(None),
+    max_identifier_length: Optional[int] = Default(None),
+    max_overflow: int = Default(10),
+    module: Optional[Any] = Default(None),
+    paramstyle: Optional[_ParamStyle] = Default(None),
+    pool: Optional[Pool] = Default(None),
+    poolclass: Optional[Type[Pool]] = Default(None),
+    pool_logging_name: Optional[str] = Default(None),
+    pool_pre_ping: bool = Default(False),
+    pool_size: int = Default(5),
+    pool_recycle: int = Default(-1),
+    pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
+    pool_timeout: float = Default(30),
+    pool_use_lifo: bool = Default(False),
+    plugins: Optional[List[str]] = Default(None),
+    query_cache_size: Optional[int] = Default(None),
+    **kwargs: Any,
+) -> _FutureEngine:
+    current_kwargs: Dict[str, Any] = {
+        "future": future,
+    }
+    if not isinstance(echo, _DefaultPlaceholder):
+        current_kwargs["echo"] = echo
+    if not isinstance(echo_pool, _DefaultPlaceholder):
+        current_kwargs["echo_pool"] = echo_pool
+    if not isinstance(enable_from_linting, _DefaultPlaceholder):
+        current_kwargs["enable_from_linting"] = enable_from_linting
+    if not isinstance(connect_args, _DefaultPlaceholder):
+        current_kwargs["connect_args"] = connect_args
+    if not isinstance(encoding, _DefaultPlaceholder):
+        current_kwargs["encoding"] = encoding
+    if not isinstance(execution_options, _DefaultPlaceholder):
+        current_kwargs["execution_options"] = execution_options
+    if not isinstance(hide_parameters, _DefaultPlaceholder):
+        current_kwargs["hide_parameters"] = hide_parameters
+    if not isinstance(implicit_returning, _DefaultPlaceholder):
+        current_kwargs["implicit_returning"] = implicit_returning
+    if not isinstance(isolation_level, _DefaultPlaceholder):
+        current_kwargs["isolation_level"] = isolation_level
+    if not isinstance(json_deserializer, _DefaultPlaceholder):
+        current_kwargs["json_deserializer"] = json_deserializer
+    if not isinstance(json_serializer, _DefaultPlaceholder):
+        current_kwargs["json_serializer"] = json_serializer
+    if not isinstance(label_length, _DefaultPlaceholder):
+        current_kwargs["label_length"] = label_length
+    if not isinstance(logging_name, _DefaultPlaceholder):
+        current_kwargs["logging_name"] = logging_name
+    if not isinstance(max_identifier_length, _DefaultPlaceholder):
+        current_kwargs["max_identifier_length"] = max_identifier_length
+    if not isinstance(max_overflow, _DefaultPlaceholder):
+        current_kwargs["max_overflow"] = max_overflow
+    if not isinstance(module, _DefaultPlaceholder):
+        current_kwargs["module"] = module
+    if not isinstance(paramstyle, _DefaultPlaceholder):
+        current_kwargs["paramstyle"] = paramstyle
+    if not isinstance(pool, _DefaultPlaceholder):
+        current_kwargs["pool"] = pool
+    if not isinstance(poolclass, _DefaultPlaceholder):
+        current_kwargs["poolclass"] = poolclass
+    if not isinstance(pool_logging_name, _DefaultPlaceholder):
+        current_kwargs["pool_logging_name"] = pool_logging_name
+    if not isinstance(pool_pre_ping, _DefaultPlaceholder):
+        current_kwargs["pool_pre_ping"] = pool_pre_ping
+    if not isinstance(pool_size, _DefaultPlaceholder):
+        current_kwargs["pool_size"] = pool_size
+    if not isinstance(pool_recycle, _DefaultPlaceholder):
+        current_kwargs["pool_recycle"] = pool_recycle
+    if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
+        current_kwargs["pool_reset_on_return"] = pool_reset_on_return
+    if not isinstance(pool_timeout, _DefaultPlaceholder):
+        current_kwargs["pool_timeout"] = pool_timeout
+    if not isinstance(pool_use_lifo, _DefaultPlaceholder):
+        current_kwargs["pool_use_lifo"] = pool_use_lifo
+    if not isinstance(plugins, _DefaultPlaceholder):
+        current_kwargs["plugins"] = plugins
+    if not isinstance(query_cache_size, _DefaultPlaceholder):
+        current_kwargs["query_cache_size"] = query_cache_size
+    current_kwargs.update(kwargs)
+    return _create_engine(url, **current_kwargs)
diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py
new file mode 100644 (file)
index 0000000..d521427
--- /dev/null
@@ -0,0 +1,79 @@
+from typing import Generic, Iterator, List, Optional, TypeVar
+
+from sqlalchemy.engine.result import Result as _Result
+from sqlalchemy.engine.result import ScalarResult as _ScalarResult
+
+_T = TypeVar("_T")
+
+
+class ScalarResult(_ScalarResult, Generic[_T]):
+    def all(self) -> List[_T]:
+        return super().all()
+
+    def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
+        return super().partitions(size)
+
+    def fetchall(self) -> List[_T]:
+        return super().fetchall()
+
+    def fetchmany(self, size: Optional[int] = None) -> List[_T]:
+        return super().fetchmany(size)
+
+    def __iter__(self) -> Iterator[_T]:
+        return super().__iter__()
+
+    def __next__(self) -> _T:
+        return super().__next__()
+
+    def first(self) -> Optional[_T]:
+        return super().first()
+
+    def one_or_none(self) -> Optional[_T]:
+        return super().one_or_none()
+
+    def one(self) -> _T:
+        return super().one()
+
+
+class Result(_Result, Generic[_T]):
+    def scalars(self, index: int = 0) -> ScalarResult[_T]:
+        return super().scalars(index)  # type: ignore
+
+    def __iter__(self) -> Iterator[_T]:  # type: ignore
+        return super().__iter__()  # type: ignore
+
+    def __next__(self) -> _T:  # type: ignore
+        return super().__next__()  # type: ignore
+
+    def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:  # type: ignore
+        return super().partitions(size)  # type: ignore
+
+    def fetchall(self) -> List[_T]:  # type: ignore
+        return super().fetchall()  # type: ignore
+
+    def fetchone(self) -> Optional[_T]:  # type: ignore
+        return super().fetchone()  # type: ignore
+
+    def fetchmany(self, size: Optional[int] = None) -> List[_T]:  # type: ignore
+        return super().fetchmany()  # type: ignore
+
+    def all(self) -> List[_T]:  # type: ignore
+        return super().all()  # type: ignore
+
+    def first(self) -> Optional[_T]:  # type: ignore
+        return super().first()  # type: ignore
+
+    def one_or_none(self) -> Optional[_T]:  # type: ignore
+        return super().one_or_none()  # type: ignore
+
+    def scalar_one(self) -> _T:
+        return super().scalar_one()  # type: ignore
+
+    def scalar_one_or_none(self) -> Optional[_T]:
+        return super().scalar_one_or_none()  # type: ignore
+
+    def one(self) -> _T:  # type: ignore
+        return super().one()  # type: ignore
+
+    def scalar(self) -> Optional[_T]:
+        return super().scalar()  # type: ignore
diff --git a/sqlmodel/ext/__init__.py b/sqlmodel/ext/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sqlmodel/ext/asyncio/__init__.py b/sqlmodel/ext/asyncio/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py
new file mode 100644 (file)
index 0000000..40e5b76
--- /dev/null
@@ -0,0 +1,62 @@
+from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
+
+from sqlalchemy import util
+from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
+from sqlalchemy.ext.asyncio import engine
+from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
+from sqlalchemy.util.concurrency import greenlet_spawn
+from sqlmodel.sql.base import Executable
+
+from ...engine.result import ScalarResult
+from ...orm.session import Session
+from ...sql.expression import Select
+
+_T = TypeVar("_T")
+
+
+class AsyncSession(_AsyncSession):
+    sync_session: Session
+
+    def __init__(
+        self,
+        bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
+        binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
+        **kw,
+    ):
+        # All the same code of the original AsyncSession
+        kw["future"] = True
+        if bind:
+            self.bind = bind
+            bind = engine._get_sync_engine_or_connection(bind)  # type: ignore
+
+        if binds:
+            self.binds = binds
+            binds = {
+                key: engine._get_sync_engine_or_connection(b)  # type: ignore
+                for key, b in binds.items()
+            }
+
+        self.sync_session = self._proxied = self._assign_proxied(  # type: ignore
+            Session(bind=bind, binds=binds, **kw)  # type: ignore
+        )
+
+    async def exec(
+        self,
+        statement: Union[Select[_T], Executable[_T]],
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        **kw: Any,
+    ) -> ScalarResult[_T]:
+        # TODO: the documentation says execution_options accepts a dict, but only
+        # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
+        execution_options = execution_options.union({"prebuffer_rows": True})  # type: ignore
+
+        return await greenlet_spawn(  # type: ignore
+            self.sync_session.exec,
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw,
+        )
diff --git a/sqlmodel/main.py b/sqlmodel/main.py
new file mode 100644 (file)
index 0000000..8036ceb
--- /dev/null
@@ -0,0 +1,631 @@
+import ipaddress
+import uuid
+import weakref
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from enum import Enum
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    AbstractSet,
+    Any,
+    Callable,
+    ClassVar,
+    Dict,
+    ForwardRef,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from pydantic import BaseModel
+from pydantic.errors import ConfigError, DictError
+from pydantic.fields import FieldInfo as PydanticFieldInfo
+from pydantic.fields import ModelField, Undefined, UndefinedType
+from pydantic.main import BaseConfig, ModelMetaclass, validate_model
+from pydantic.typing import NoArgAnyCallable, resolve_annotations
+from pydantic.utils import ROOT_KEY, Representation, ValueItems
+from sqlalchemy import (
+    Boolean,
+    Column,
+    Date,
+    DateTime,
+    Float,
+    ForeignKey,
+    Integer,
+    Interval,
+    Numeric,
+    inspect,
+)
+from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
+from sqlalchemy.orm.attributes import set_attribute
+from sqlalchemy.orm.decl_api import DeclarativeMeta
+from sqlalchemy.orm.instrumentation import is_instrumented
+from sqlalchemy.sql.schema import MetaData
+from sqlalchemy.sql.sqltypes import LargeBinary, Time
+
+from .sql.sqltypes import GUID, AutoString
+
+_T = TypeVar("_T")
+
+
+def __dataclass_transform__(
+    *,
+    eq_default: bool = True,
+    order_default: bool = False,
+    kw_only_default: bool = False,
+    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
+) -> Callable[[_T], _T]:
+    return lambda a: a
+
+
+class FieldInfo(PydanticFieldInfo):
+    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
+        primary_key = kwargs.pop("primary_key", False)
+        nullable = kwargs.pop("nullable", Undefined)
+        foreign_key = kwargs.pop("foreign_key", Undefined)
+        index = kwargs.pop("index", Undefined)
+        sa_column = kwargs.pop("sa_column", Undefined)
+        sa_column_args = kwargs.pop("sa_column_args", Undefined)
+        sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
+        if sa_column is not Undefined:
+            if sa_column_args is not Undefined:
+                raise RuntimeError(
+                    "Passing sa_column_args is not supported when "
+                    "also passing a sa_column"
+                )
+            if sa_column_kwargs is not Undefined:
+                raise RuntimeError(
+                    "Passing sa_column_kwargs is not supported when "
+                    "also passing a sa_column"
+                )
+        super().__init__(default=default, **kwargs)
+        self.primary_key = primary_key
+        self.nullable = nullable
+        self.foreign_key = foreign_key
+        self.index = index
+        self.sa_column = sa_column
+        self.sa_column_args = sa_column_args
+        self.sa_column_kwargs = sa_column_kwargs
+
+
+class RelationshipInfo(Representation):
+    def __init__(
+        self,
+        *,
+        back_populates: Optional[str] = None,
+        link_model: Optional[Any] = None,
+        sa_relationship: Optional[RelationshipProperty] = None,
+        sa_relationship_args: Optional[Sequence[Any]] = None,
+        sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> None:
+        if sa_relationship is not None:
+            if sa_relationship_args is not None:
+                raise RuntimeError(
+                    "Passing sa_relationship_args is not supported when "
+                    "also passing a sa_relationship"
+                )
+            if sa_relationship_kwargs is not None:
+                raise RuntimeError(
+                    "Passing sa_relationship_kwargs is not supported when "
+                    "also passing a sa_relationship"
+                )
+        self.back_populates = back_populates
+        self.link_model = link_model
+        self.sa_relationship = sa_relationship
+        self.sa_relationship_args = sa_relationship_args
+        self.sa_relationship_kwargs = sa_relationship_kwargs
+
+
+def Field(
+    default: Any = Undefined,
+    *,
+    default_factory: Optional[NoArgAnyCallable] = None,
+    alias: str = None,
+    title: str = None,
+    description: str = None,
+    exclude: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    include: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    const: bool = None,
+    gt: float = None,
+    ge: float = None,
+    lt: float = None,
+    le: float = None,
+    multiple_of: float = None,
+    min_items: int = None,
+    max_items: int = None,
+    min_length: int = None,
+    max_length: int = None,
+    allow_mutation: bool = True,
+    regex: str = None,
+    primary_key: bool = False,
+    foreign_key: Optional[Any] = None,
+    nullable: Union[bool, UndefinedType] = Undefined,
+    index: Union[bool, UndefinedType] = Undefined,
+    sa_column: Union[Column, UndefinedType] = Undefined,
+    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
+    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+    schema_extra: Optional[Dict[str, Any]] = None,
+) -> Any:
+    current_schema_extra = schema_extra or {}
+    field_info = FieldInfo(
+        default,
+        default_factory=default_factory,
+        alias=alias,
+        title=title,
+        description=description,
+        exclude=exclude,
+        include=include,
+        const=const,
+        gt=gt,
+        ge=ge,
+        lt=lt,
+        le=le,
+        multiple_of=multiple_of,
+        min_items=min_items,
+        max_items=max_items,
+        min_length=min_length,
+        max_length=max_length,
+        allow_mutation=allow_mutation,
+        regex=regex,
+        primary_key=primary_key,
+        foreign_key=foreign_key,
+        nullable=nullable,
+        index=index,
+        sa_column=sa_column,
+        sa_column_args=sa_column_args,
+        sa_column_kwargs=sa_column_kwargs,
+        **current_schema_extra,
+    )
+    field_info._validate()
+    return field_info
+
+
+def Relationship(
+    *,
+    back_populates: Optional[str] = None,
+    link_model: Optional[Any] = None,
+    sa_relationship: Optional[RelationshipProperty] = None,
+    sa_relationship_args: Optional[Sequence[Any]] = None,
+    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
+) -> Any:
+    relationship_info = RelationshipInfo(
+        back_populates=back_populates,
+        link_model=link_model,
+        sa_relationship=sa_relationship,
+        sa_relationship_args=sa_relationship_args,
+        sa_relationship_kwargs=sa_relationship_kwargs,
+    )
+    return relationship_info
+
+
+@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
+class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
+    __sqlmodel_relationships__: Dict[str, RelationshipInfo]
+    __config__: Type[BaseConfig]
+    __fields__: Dict[str, ModelField]
+
+    # Replicate SQLAlchemy
+    def __setattr__(cls, name: str, value: Any) -> None:
+        if getattr(cls.__config__, "table", False):  # type: ignore
+            DeclarativeMeta.__setattr__(cls, name, value)
+        else:
+            super().__setattr__(name, value)
+
+    def __delattr__(cls, name: str) -> None:
+        if getattr(cls.__config__, "table", False):  # type: ignore
+            DeclarativeMeta.__delattr__(cls, name)
+        else:
+            super().__delattr__(name)
+
+    # From Pydantic
+    def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
+        relationships: Dict[str, RelationshipInfo] = {}
+        dict_for_pydantic = {}
+        original_annotations = resolve_annotations(
+            class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
+        )
+        pydantic_annotations = {}
+        relationship_annotations = {}
+        for k, v in class_dict.items():
+            if isinstance(v, RelationshipInfo):
+                relationships[k] = v
+            else:
+                dict_for_pydantic[k] = v
+        for k, v in original_annotations.items():
+            if k in relationships:
+                relationship_annotations[k] = v
+            else:
+                pydantic_annotations[k] = v
+        dict_used = {
+            **dict_for_pydantic,
+            "__weakref__": None,
+            "__sqlmodel_relationships__": relationships,
+            "__annotations__": pydantic_annotations,
+        }
+        # Duplicate logic from Pydantic to filter config kwargs because if they are
+        # passed directly including the registry Pydantic will pass them over to the
+        # superclass causing an error
+        allowed_config_kwargs: Set[str] = {
+            key
+            for key in dir(BaseConfig)
+            if not (
+                key.startswith("__") and key.endswith("__")
+            )  # skip dunder methods and attributes
+        }
+        pydantic_kwargs = kwargs.copy()
+        config_kwargs = {
+            key: pydantic_kwargs.pop(key)
+            for key in pydantic_kwargs.keys() & allowed_config_kwargs
+        }
+        new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
+        new_cls.__annotations__ = {
+            **relationship_annotations,
+            **pydantic_annotations,
+            **new_cls.__annotations__,
+        }
+
+        def get_config(name: str) -> Any:
+            config_class_value = getattr(new_cls.__config__, name, Undefined)
+            if config_class_value is not Undefined:
+                return config_class_value
+            kwarg_value = kwargs.get(name, Undefined)
+            if kwarg_value is not Undefined:
+                return kwarg_value
+            return Undefined
+
+        config_table = get_config("table")
+        if config_table is True:
+            # If it was passed by kwargs, ensure it's also set in config
+            new_cls.__config__.table = config_table
+            for k, v in new_cls.__fields__.items():
+                col = get_column_from_field(v)
+                setattr(new_cls, k, col)
+            # Set a config flag to tell FastAPI that this should be read with a field
+            # in orm_mode instead of preemptively converting it to a dict.
+            # This could be done by reading new_cls.__config__.table in FastAPI, but
+            # that's very specific about SQLModel, so let's have another config that
+            # other future tools based on Pydantic can use.
+            new_cls.__config__.read_with_orm_mode = True
+
+        config_registry = get_config("registry")
+        if config_registry is not Undefined:
+            config_registry = cast(registry, config_registry)
+            # If it was passed by kwargs, ensure it's also set in config
+            new_cls.__config__.registry = config_table
+            setattr(new_cls, "_sa_registry", config_registry)
+            setattr(new_cls, "metadata", config_registry.metadata)
+            setattr(new_cls, "__abstract__", True)
+        return new_cls
+
+    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
+    def __init__(
+        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
+    ) -> None:
+        # Only one of the base classes (or the current one) should be a table model
+        # this allows FastAPI cloning a SQLModel for the response_model without
+        # trying to create a new SQLAlchemy, for a new table, with the same name, that
+        # triggers an error
+        base_is_table = False
+        for base in bases:
+            config = getattr(base, "__config__")
+            if config and getattr(config, "table", False):
+                base_is_table = True
+                break
+        if getattr(cls.__config__, "table", False) and not base_is_table:
+            dict_used = dict_.copy()
+            for field_name, field_value in cls.__fields__.items():
+                dict_used[field_name] = get_column_from_field(field_value)
+            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
+                if rel_info.sa_relationship:
+                    # There's a SQLAlchemy relationship declared, that takes precedence
+                    # over anything else, use that and continue with the next attribute
+                    dict_used[rel_name] = rel_info.sa_relationship
+                    continue
+                ann = cls.__annotations__[rel_name]
+                temp_field = ModelField.infer(
+                    name=rel_name,
+                    value=rel_info,
+                    annotation=ann,
+                    class_validators=None,
+                    config=BaseConfig,
+                )
+                relationship_to = temp_field.type_
+                if isinstance(temp_field.type_, ForwardRef):
+                    relationship_to = temp_field.type_.__forward_arg__
+                rel_kwargs: Dict[str, Any] = {}
+                if rel_info.back_populates:
+                    rel_kwargs["back_populates"] = rel_info.back_populates
+                if rel_info.link_model:
+                    ins = inspect(rel_info.link_model)
+                    local_table = getattr(ins, "local_table")
+                    if local_table is None:
+                        raise RuntimeError(
+                            "Couldn't find the secondary table for "
+                            f"model {rel_info.link_model}"
+                        )
+                    rel_kwargs["secondary"] = local_table
+                rel_args: List[Any] = []
+                if rel_info.sa_relationship_args:
+                    rel_args.extend(rel_info.sa_relationship_args)
+                if rel_info.sa_relationship_kwargs:
+                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
+                rel_value: RelationshipProperty = relationship(
+                    relationship_to, *rel_args, **rel_kwargs
+                )
+                dict_used[rel_name] = rel_value
+            DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
+        else:
+            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
+
+
+def get_sqlachemy_type(field: ModelField) -> Any:
+    if issubclass(field.type_, str):
+        if field.field_info.max_length:
+            return AutoString(length=field.field_info.max_length)
+        return AutoString
+    if issubclass(field.type_, float):
+        return Float
+    if issubclass(field.type_, bool):
+        return Boolean
+    if issubclass(field.type_, int):
+        return Integer
+    if issubclass(field.type_, datetime):
+        return DateTime
+    if issubclass(field.type_, date):
+        return Date
+    if issubclass(field.type_, timedelta):
+        return Interval
+    if issubclass(field.type_, time):
+        return Time
+    if issubclass(field.type_, Enum):
+        return Enum
+    if issubclass(field.type_, bytes):
+        return LargeBinary
+    if issubclass(field.type_, Decimal):
+        return Numeric
+    if issubclass(field.type_, ipaddress.IPv4Address):
+        return AutoString
+    if issubclass(field.type_, ipaddress.IPv4Network):
+        return AutoString
+    if issubclass(field.type_, ipaddress.IPv6Address):
+        return AutoString
+    if issubclass(field.type_, ipaddress.IPv6Network):
+        return AutoString
+    if issubclass(field.type_, Path):
+        return AutoString
+    if issubclass(field.type_, uuid.UUID):
+        return GUID
+
+
+def get_column_from_field(field: ModelField) -> Column:
+    sa_column = getattr(field.field_info, "sa_column", Undefined)
+    if isinstance(sa_column, Column):
+        return sa_column
+    sa_type = get_sqlachemy_type(field)
+    primary_key = getattr(field.field_info, "primary_key", False)
+    nullable = not field.required
+    index = getattr(field.field_info, "index", Undefined)
+    if index is Undefined:
+        index = True
+    if hasattr(field.field_info, "nullable"):
+        field_nullable = getattr(field.field_info, "nullable")
+        if field_nullable != Undefined:
+            nullable = field_nullable
+    args = []
+    foreign_key = getattr(field.field_info, "foreign_key", None)
+    if foreign_key:
+        args.append(ForeignKey(foreign_key))
+    kwargs = {
+        "primary_key": primary_key,
+        "nullable": nullable,
+        "index": index,
+    }
+    sa_default = Undefined
+    if field.field_info.default_factory:
+        sa_default = field.field_info.default_factory
+    elif field.field_info.default is not Undefined:
+        sa_default = field.field_info.default
+    if sa_default is not Undefined:
+        kwargs["default"] = sa_default
+    sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
+    if sa_column_args is not Undefined:
+        args.extend(list(cast(Sequence, sa_column_args)))
+    sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
+    if sa_column_kwargs is not Undefined:
+        kwargs.update(cast(dict, sa_column_kwargs))
+    return Column(sa_type, *args, **kwargs)
+
+
+class_registry = weakref.WeakValueDictionary()  # type: ignore
+
+default_registry = registry()
+
+
+class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
+    # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
+    __slots__ = ("__weakref__",)
+    __tablename__: ClassVar[Union[str, Callable[..., str]]]
+    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
+    __name__: ClassVar[str]
+    metadata: ClassVar[MetaData]
+
+    class Config:
+        orm_mode = True
+
+    def __new__(cls, *args, **kwargs) -> Any:
+        new_object = super().__new__(cls)
+        # SQLAlchemy doesn't call __init__ on the base class
+        # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
+        # Set __fields_set__ here, that would have been set when calling __init__
+        # in the Pydantic model so that when SQLAlchemy sets attributes that are
+        # added (e.g. when querying from DB) to the __fields_set__, this already exists
+        object.__setattr__(new_object, "__fields_set__", set())
+        return new_object
+
+    def __init__(__pydantic_self__, **data: Any) -> None:
+        # Uses something other than `self` the first arg to allow "self" as a
+        # settable attribute
+        if TYPE_CHECKING:
+            __pydantic_self__.__dict__: Dict[str, Any] = {}
+            __pydantic_self__.__fields_set__: Set[str] = set()
+        values, fields_set, validation_error = validate_model(
+            __pydantic_self__.__class__, data
+        )
+        # Only raise errors if not a SQLModel model
+        if (
+            not getattr(__pydantic_self__.__config__, "table", False)
+            and validation_error
+        ):
+            raise validation_error
+        # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
+        # can handle them
+        # object.__setattr__(__pydantic_self__, '__dict__', values)
+        object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
+        for key, value in values.items():
+            setattr(__pydantic_self__, key, value)
+        non_pydantic_keys = data.keys() - values.keys()
+        for key in non_pydantic_keys:
+            if key in __pydantic_self__.__sqlmodel_relationships__:
+                setattr(__pydantic_self__, key, data[key])
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name in {"_sa_instance_state"}:
+            self.__dict__[name] = value
+            return
+        else:
+            # Set in SQLAlchemy, before Pydantic to trigger events and updates
+            if getattr(self.__config__, "table", False):
+                if is_instrumented(self, name):
+                    set_attribute(self, name, value)
+            # Set in Pydantic model to trigger possible validation changes, only for
+            # non relationship values
+            if name not in self.__sqlmodel_relationships__:
+                super().__setattr__(name, value)
+
+    @classmethod
+    def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
+        # Duplicated from Pydantic
+        if not cls.__config__.orm_mode:
+            raise ConfigError(
+                "You must have the config attribute orm_mode=True to use from_orm"
+            )
+        obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
+        # SQLModel, support update dict
+        if update is not None:
+            obj = {**obj, **update}
+        # End SQLModel support dict
+        if not getattr(cls.__config__, "table", False):
+            # If not table, normal Pydantic code
+            m = cls.__new__(cls)
+        else:
+            # If table, create the new instance normally to make SQLAlchemy create
+            # the _sa_instance_state attribute
+            m = cls()
+        values, fields_set, validation_error = validate_model(cls, obj)
+        if validation_error:
+            raise validation_error
+        # Updated to trigger SQLAlchemy internal handling
+        if not getattr(cls.__config__, "table", False):
+            object.__setattr__(m, "__dict__", values)
+        else:
+            for key, value in values.items():
+                setattr(m, key, value)
+        # Continue with standard Pydantic logic
+        object.__setattr__(m, "__fields_set__", fields_set)
+        m._init_private_attributes()
+        return m
+
+    @classmethod
+    def parse_obj(
+        cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
+    ) -> "SQLModel":
+        obj = cls._enforce_dict_if_root(obj)
+        # SQLModel, support update dict
+        if update is not None:
+            obj = {**obj, **update}
+        # End SQLModel support dict
+        return super().parse_obj(obj)
+
+    def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
+        # Don't show SQLAlchemy private attributes
+        return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
+
+    # From Pydantic, override to enforce validation with dict
+    @classmethod
+    def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel":
+        if isinstance(value, cls):
+            return value.copy() if cls.__config__.copy_on_model_validation else value
+
+        value = cls._enforce_dict_if_root(value)
+        if isinstance(value, dict):
+            values, fields_set, validation_error = validate_model(cls, value)
+            if validation_error:
+                raise validation_error
+            model = cls(**values)
+            # Reset fields set, this would have been done in Pydantic in __init__
+            object.__setattr__(model, "__fields_set__", fields_set)
+            return model
+        elif cls.__config__.orm_mode:
+            return cls.from_orm(value)
+        elif cls.__custom_root_type__:
+            return cls.parse_obj(value)
+        else:
+            try:
+                value_as_dict = dict(value)
+            except (TypeError, ValueError) as e:
+                raise DictError() from e
+            return cls(**value_as_dict)
+
+    # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
+    def _calculate_keys(
+        self,
+        include: Optional[Mapping[Union[int, str], Any]],
+        exclude: Optional[Mapping[Union[int, str], Any]],
+        exclude_unset: bool,
+        update: Optional[Dict[str, Any]] = None,
+    ) -> Optional[AbstractSet[str]]:
+        if include is None and exclude is None and exclude_unset is False:
+            # Original in Pydantic:
+            # return None
+            # Updated to not return SQLAlchemy attributes
+            # Do not include relationships as that would easily lead to infinite
+            # recursion, or traversing the whole database
+            return self.__fields__.keys()  # | self.__sqlmodel_relationships__.keys()
+
+        keys: AbstractSet[str]
+        if exclude_unset:
+            keys = self.__fields_set__.copy()
+        else:
+            # Original in Pydantic:
+            # keys = self.__dict__.keys()
+            # Updated to not return SQLAlchemy attributes
+            # Do not include relationships as that would easily lead to infinite
+            # recursion, or traversing the whole database
+            keys = self.__fields__.keys()  # | self.__sqlmodel_relationships__.keys()
+
+        if include is not None:
+            keys &= include.keys()
+
+        if update:
+            keys -= update.keys()
+
+        if exclude:
+            keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)}
+
+        return keys
+
+    @declared_attr  # type: ignore
+    def __tablename__(cls) -> str:
+        return cls.__name__.lower()
diff --git a/sqlmodel/orm/__init__.py b/sqlmodel/orm/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py
new file mode 100644 (file)
index 0000000..a96544e
--- /dev/null
@@ -0,0 +1,135 @@
+from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
+
+from sqlalchemy import util
+from sqlalchemy.orm import Query as _Query
+from sqlalchemy.orm import Session as _Session
+from sqlalchemy.sql.base import Executable as _Executable
+from sqlmodel.sql.expression import Select, SelectOfScalar
+from typing_extensions import Literal
+
+from ..engine.result import Result, ScalarResult
+from ..sql.base import Executable
+
+_T = TypeVar("_T")
+
+
+class Session(_Session):
+    @overload
+    def exec(
+        self,
+        statement: Select[_T],
+        *,
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+        **kw: Any,
+    ) -> Union[Result[_T]]:
+        ...
+
+    @overload
+    def exec(
+        self,
+        statement: SelectOfScalar[_T],
+        *,
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+        **kw: Any,
+    ) -> Union[ScalarResult[_T]]:
+        ...
+
+    def exec(
+        self,
+        statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
+        *,
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+        **kw: Any,
+    ) -> Union[Result[_T], ScalarResult[_T]]:
+        results = super().execute(
+            statement,
+            params=params,
+            execution_options=execution_options,  # type: ignore
+            bind_arguments=bind_arguments,
+            _parent_execute_state=_parent_execute_state,
+            _add_event=_add_event,
+            **kw,
+        )
+        if isinstance(statement, SelectOfScalar):
+            return results.scalars()  # type: ignore
+        return results  # type: ignore
+
+    def execute(
+        self,
+        statement: _Executable,
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+        **kw: Any,
+    ) -> Result[Any]:
+        """
+        🚨 You probably want to use `session.exec()` instead of `session.execute()`.
+
+        This is the original SQLAlchemy `session.execute()` method that returns objects
+        of type `Row`, and that you have to call `scalars()` to get the model objects.
+
+        For example:
+
+        ```Python
+        heroes = session.execute(select(Hero)).scalars().all()
+        ```
+
+        instead you could use `exec()`:
+
+        ```Python
+        heroes = session.exec(select(Hero)).all()
+        ```
+        """
+        return super().execute(  # type: ignore
+            statement,
+            params=params,
+            execution_options=execution_options,  # type: ignore
+            bind_arguments=bind_arguments,
+            _parent_execute_state=_parent_execute_state,
+            _add_event=_add_event,
+            **kw,
+        )
+
+    def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
+        """
+        🚨 You probably want to use `session.exec()` instead of `session.query()`.
+
+        `session.exec()` is SQLModel's own short version with increased type
+        annotations.
+
+        Or otherwise you might want to use `session.execute()` instead of
+        `session.query()`.
+        """
+        return super().query(*entities, **kwargs)
+
+    def get(
+        self,
+        entity: _T,
+        ident: Any,
+        options: Optional[Sequence[Any]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
+        identity_token: Optional[Any] = None,
+    ) -> _T:
+        return super().get(
+            entity,
+            ident,
+            options=options,
+            populate_existing=populate_existing,
+            with_for_update=with_for_update,
+            identity_token=identity_token,
+        )
diff --git a/sqlmodel/pool/__init__.py b/sqlmodel/pool/__init__.py
new file mode 100644 (file)
index 0000000..20bb952
--- /dev/null
@@ -0,0 +1 @@
+from sqlalchemy.pool import StaticPool as StaticPool  # noqa: F401
diff --git a/sqlmodel/sql/__init__.py b/sqlmodel/sql/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sqlmodel/sql/base.py b/sqlmodel/sql/base.py
new file mode 100644 (file)
index 0000000..129e4d4
--- /dev/null
@@ -0,0 +1,11 @@
+from typing import Generic, TypeVar
+
+from sqlalchemy.sql.base import Executable as _Executable
+
+_T = TypeVar("_T")
+
+
+class Executable(_Executable, Generic[_T]):
+    def __init__(self, *args, **kwargs):
+        self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
+        super(_Executable, self).__init__(*args, **kwargs)
diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py
new file mode 100644 (file)
index 0000000..e8a922e
--- /dev/null
@@ -0,0 +1,459 @@
+# WARNING: do not modify this code, it is generated by expression.py.jinja2
+
+import sys
+from datetime import datetime
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Mapping,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+from uuid import UUID
+
+from sqlalchemy import Column
+from sqlalchemy.orm import InstrumentedAttribute
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.expression import Select as _Select
+
+_TSelect = TypeVar("_TSelect")
+
+# Workaround Generics incompatibility in Python 3.6
+# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
+if sys.version_info.minor >= 7:
+
+    class Select(_Select, Generic[_TSelect]):
+        pass
+
+    # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
+    # purpose. This is the same as a normal SQLAlchemy Select class where there's only one
+    # entity, so the result will be converted to a scalar by default. This way writing
+    # for loops on the results will feel natural.
+    class SelectOfScalar(_Select, Generic[_TSelect]):
+        pass
+
+
+else:
+    from typing import GenericMeta  # type: ignore
+
+    class GenericSelectMeta(GenericMeta, _Select.__class__):  # type: ignore
+        pass
+
+    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+        pass
+
+    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+        pass
+
+    # Cast them for editors to work correctly, from several tricks tried, this works
+    # for both VS Code and PyCharm
+    Select = cast("Select", _Py36Select)  # type: ignore
+    SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar)  # type: ignore
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from ..main import SQLModel
+
+# Generated TypeVars start
+
+
+_TScalar_0 = TypeVar(
+    "_TScalar_0",
+    Column,
+    Sequence,
+    Mapping,
+    UUID,
+    datetime,
+    float,
+    int,
+    bool,
+    bytes,
+    str,
+    None,
+)
+
+_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
+
+
+_TScalar_1 = TypeVar(
+    "_TScalar_1",
+    Column,
+    Sequence,
+    Mapping,
+    UUID,
+    datetime,
+    float,
+    int,
+    bool,
+    bytes,
+    str,
+    None,
+)
+
+_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
+
+
+_TScalar_2 = TypeVar(
+    "_TScalar_2",
+    Column,
+    Sequence,
+    Mapping,
+    UUID,
+    datetime,
+    float,
+    int,
+    bool,
+    bytes,
+    str,
+    None,
+)
+
+_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
+
+
+_TScalar_3 = TypeVar(
+    "_TScalar_3",
+    Column,
+    Sequence,
+    Mapping,
+    UUID,
+    datetime,
+    float,
+    int,
+    bool,
+    bytes,
+    str,
+    None,
+)
+
+_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
+
+
+# Generated TypeVars end
+
+
+@overload
+def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
+    ...
+
+
+@overload
+def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
+    ...
+
+
+# Generated overloads start
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    entity_2: _TScalar_2,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    entity_2: Type[_TModel_2],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    entity_2: _TScalar_2,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    entity_2: Type[_TModel_2],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    entity_2: _TScalar_2,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    entity_2: Type[_TModel_2],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    entity_2: _TScalar_2,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    entity_2: Type[_TModel_2],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    entity_2: _TScalar_2,
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    entity_2: _TScalar_2,
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    entity_2: Type[_TModel_2],
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: _TScalar_1,
+    entity_2: Type[_TModel_2],
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    entity_2: _TScalar_2,
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    entity_2: _TScalar_2,
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    entity_2: Type[_TModel_2],
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: _TScalar_0,
+    entity_1: Type[_TModel_1],
+    entity_2: Type[_TModel_2],
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    entity_2: _TScalar_2,
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    entity_2: _TScalar_2,
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    entity_2: Type[_TModel_2],
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: _TScalar_1,
+    entity_2: Type[_TModel_2],
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    entity_2: _TScalar_2,
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    entity_2: _TScalar_2,
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    entity_2: Type[_TModel_2],
+    entity_3: _TScalar_3,
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
+    ...
+
+
+@overload
+def select(  # type: ignore
+    entity_0: Type[_TModel_0],
+    entity_1: Type[_TModel_1],
+    entity_2: Type[_TModel_2],
+    entity_3: Type[_TModel_3],
+    **kw: Any,
+) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
+    ...
+
+
+# Generated overloads end
+
+
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
+    if len(entities) == 1:
+        return SelectOfScalar._create(*entities, **kw)  # type: ignore
+    return Select._create(*entities, **kw)  # type: ignore
+
+
+# TODO: add several @overload from Python types to SQLAlchemy equivalents
+def col(column_expression: Any) -> ColumnClause:
+    if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
+        raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
+    return column_expression
diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2
new file mode 100644 (file)
index 0000000..b39d636
--- /dev/null
@@ -0,0 +1,119 @@
+import sys
+from datetime import datetime
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Mapping,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+from uuid import UUID
+
+from sqlalchemy import Column
+from sqlalchemy.orm import InstrumentedAttribute
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.expression import Select as _Select
+
+_TSelect = TypeVar("_TSelect")
+
+# Workaround Generics incompatibility in Python 3.6
+# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
+if sys.version_info.minor >= 7:
+
+    class Select(_Select, Generic[_TSelect]):
+        pass
+
+    # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
+    # purpose. This is the same as a normal SQLAlchemy Select class where there's only one
+    # entity, so the result will be converted to a scalar by default. This way writing
+    # for loops on the results will feel natural.
+    class SelectOfScalar(_Select, Generic[_TSelect]):
+        pass
+
+
+else:
+    from typing import GenericMeta  # type: ignore
+
+    class GenericSelectMeta(GenericMeta, _Select.__class__):  # type: ignore
+        pass
+
+    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+        pass
+
+    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+        pass
+
+    # Cast them for editors to work correctly, from several tricks tried, this works
+    # for both VS Code and PyCharm
+    Select = cast("Select", _Py36Select)  # type: ignore
+    SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar)  # type: ignore
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from ..main import SQLModel
+
+# Generated TypeVars start
+
+{% for i in range(number_of_types) %}
+_TScalar_{{ i }} = TypeVar(
+    "_TScalar_{{ i }}",
+    Column,
+    Sequence,
+    Mapping,
+    UUID,
+    datetime,
+    float,
+    int,
+    bool,
+    bytes,
+    str,
+    None,
+)
+
+_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
+
+{% endfor %}
+
+# Generated TypeVars end
+
+@overload
+def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
+    ...
+
+
+@overload
+def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
+    ...
+
+
+# Generated overloads start
+
+{% for signature in signatures %}
+
+@overload
+def select(  # type: ignore
+    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
+    ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
+    ...
+
+{% endfor %}
+
+# Generated overloads end
+
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
+    if len(entities) == 1:
+        return SelectOfScalar._create(*entities, **kw)  # type: ignore
+    return Select._create(*entities, **kw)
+
+
+# TODO: add several @overload from Python types to SQLAlchemy equivalents
+def col(column_expression: Any) -> ColumnClause:
+    if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
+        raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
+    return column_expression
diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py
new file mode 100644 (file)
index 0000000..e7b77b8
--- /dev/null
@@ -0,0 +1,60 @@
+import uuid
+from typing import Any, cast
+
+from sqlalchemy import types
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.types import CHAR, TypeDecorator
+
+
+class AutoString(types.TypeDecorator):
+
+    impl = types.String
+    cache_ok = True
+    mysql_default_length = 255
+
+    def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
+        impl = cast(types.String, self.impl)
+        if impl.length is None and dialect.name == "mysql":
+            return dialect.type_descriptor(types.String(self.mysql_default_length))  # type: ignore
+        return super().load_dialect_impl(dialect)
+
+
+# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
+# with small modifications
+class GUID(TypeDecorator):
+    """Platform-independent GUID type.
+
+    Uses PostgreSQL's UUID type, otherwise uses
+    CHAR(32), storing as stringified hex values.
+
+    """
+
+    impl = CHAR
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect):
+        if dialect.name == "postgresql":
+            return dialect.type_descriptor(UUID())
+        else:
+            return dialect.type_descriptor(CHAR(32))
+
+    def process_bind_param(self, value, dialect):
+        if value is None:
+            return value
+        elif dialect.name == "postgresql":
+            return str(value)
+        else:
+            if not isinstance(value, uuid.UUID):
+                return f"{uuid.UUID(value).int:x}"
+            else:
+                # hexstring
+                return f"{value.int:x}"
+
+    def process_result_value(self, value, dialect):
+        if value is None:
+            return value
+        else:
+            if not isinstance(value, uuid.UUID):
+                value = uuid.UUID(value)
+            return value