From: Mike Bayer Date: Thu, 17 Mar 2022 20:18:55 +0000 (-0400) Subject: pep 484 for types X-Git-Tag: rel_2_0_0b1~417^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6c3d738757d7be32dc9f99d8e1c6b5c81c596d5f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep 484 for types strict types type_api.py, including TypeDecorator, NativeForEmulated, etc. Change-Id: Ib2eba26de0981324a83733954cb7044a29bbd7db --- diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 61b40d935e..59951cd041 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -143,7 +143,9 @@ class PyODBCConnector(Connector): def is_disconnect( self, e: Exception, - connection: Optional[pool.PoolProxiedConnection], + connection: Optional[ + Union[pool.PoolProxiedConnection, interfaces.DBAPIConnection] + ], cursor: Optional[interfaces.DBAPICursor], ) -> bool: if isinstance(e, self.dbapi.ProgrammingError): diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index dd792f70d7..07ff495a71 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2311,7 +2311,7 @@ class MSDDLCompiler(compiler.DDLCompiler): if column.computed is not None: colspec += " " + self.process(column.computed) else: - colspec += " " + self.dialect.type_compiler.process( + colspec += " " + self.dialect.type_compiler_instance.process( column.type, type_expression=column ) @@ -2719,7 +2719,7 @@ class MSDialect(default.DefaultDialect): statement_compiler = MSSQLCompiler ddl_compiler = MSDDLCompiler - type_compiler = MSTypeCompiler + type_compiler_cls = MSTypeCompiler preparer = MSIdentifierPreparer construct_arguments = [ diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index b2ccfc90f2..cfdc3deb2b 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1414,25 +1414,25 @@ class MySQLCompiler(compiler.SQLCompiler): sqltypes.Time, ), ): - return self.dialect.type_compiler.process(type_) + return self.dialect.type_compiler_instance.process(type_) elif isinstance(type_, sqltypes.String) and not isinstance( type_, (ENUM, SET) ): adapted = CHAR._adapt_string_for_cast(type_) - return self.dialect.type_compiler.process(adapted) + return self.dialect.type_compiler_instance.process(adapted) elif isinstance(type_, sqltypes._Binary): return "BINARY" elif isinstance(type_, sqltypes.JSON): return "JSON" elif isinstance(type_, sqltypes.NUMERIC): - return self.dialect.type_compiler.process(type_).replace( + return self.dialect.type_compiler_instance.process(type_).replace( "NUMERIC", "DECIMAL" ) elif ( isinstance(type_, sqltypes.Float) and self.dialect._support_float_cast ): - return self.dialect.type_compiler.process(type_) + return self.dialect.type_compiler_instance.process(type_) else: return None @@ -1442,7 +1442,9 @@ class MySQLCompiler(compiler.SQLCompiler): util.warn( "Datatype %s does not support CAST on MySQL/MariaDb; " "the CAST will be skipped." - % self.dialect.type_compiler.process(cast.typeclause.type) + % self.dialect.type_compiler_instance.process( + cast.typeclause.type + ) ) return self.process(cast.clause.self_group(), **kw) @@ -1699,7 +1701,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec = [ self.preparer.format_column(column), - self.dialect.type_compiler.process( + self.dialect.type_compiler_instance.process( column.type, type_expression=column ), ] @@ -2358,7 +2360,7 @@ class MySQLDialect(default.DefaultDialect): statement_compiler = MySQLCompiler ddl_compiler = MySQLDDLCompiler - type_compiler = MySQLTypeCompiler + type_compiler_cls = MySQLTypeCompiler ischema_names = ischema_names preparer = MySQLIdentifierPreparer diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 1ae58b8f47..3ee38c0cf1 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -743,9 +743,6 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): day_precision=self.day_precision, ) - def coerce_compared_value(self, op, value): - return self - class ROWID(sqltypes.TypeEngine): """Oracle ROWID type. @@ -1537,7 +1534,7 @@ class OracleDialect(default.DefaultDialect): statement_compiler = OracleCompiler ddl_compiler = OracleDDLCompiler - type_compiler = OracleTypeCompiler + type_compiler_cls = OracleTypeCompiler preparer = OracleIdentifierPreparer execution_ctx_cls = OracleExecutionContext diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index e265cd0f7f..a1401ea031 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1712,9 +1712,6 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): def python_type(self): return dt.timedelta - def coerce_compared_value(self, op, value): - return self - PGInterval = INTERVAL @@ -2168,7 +2165,7 @@ ischema_names = { class PGCompiler(compiler.SQLCompiler): def render_bind_cast(self, type_, dbapi_type, sqltext): return f"""{sqltext}::{ - self.dialect.type_compiler.process( + self.dialect.type_compiler_instance.process( dbapi_type, identifier_preparer=self.preparer ) }""" @@ -2318,7 +2315,7 @@ class PGCompiler(compiler.SQLCompiler): return "SELECT %s WHERE 1!=1" % ( ", ".join( "CAST(NULL AS %s)" - % self.dialect.type_compiler.process( + % self.dialect.type_compiler_instance.process( INTEGER() if type_._isnull else type_ ) for type_ in element_types or [INTEGER()] @@ -2604,7 +2601,7 @@ class PGDDLCompiler(compiler.DDLCompiler): else: colspec += " SERIAL" else: - colspec += " " + self.dialect.type_compiler.process( + colspec += " " + self.dialect.type_compiler_instance.process( column.type, type_expression=column, identifier_preparer=self.preparer, @@ -3225,7 +3222,7 @@ class PGDialect(default.DefaultDialect): statement_compiler = PGCompiler ddl_compiler = PGDDLCompiler - type_compiler = PGTypeCompiler + type_compiler_cls = PGTypeCompiler preparer = PGIdentifierPreparer execution_ctx_cls = PGExecutionContext inspector = PGInspector diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 79068c75f0..03f35d5e2d 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1426,7 +1426,7 @@ class SQLiteCompiler(compiler.SQLCompiler): class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - coltype = self.dialect.type_compiler.process( + coltype = self.dialect.type_compiler_instance.process( column.type, type_expression=column ) colspec = self.preparer.format_column(column) + " " + coltype @@ -1815,7 +1815,7 @@ class SQLiteDialect(default.DefaultDialect): execution_ctx_cls = SQLiteExecutionContext statement_compiler = SQLiteCompiler ddl_compiler = SQLiteDDLCompiler - type_compiler = SQLiteTypeCompiler + type_compiler_cls = SQLiteTypeCompiler preparer = SQLiteIdentifierPreparer ischema_names = ischema_names colspecs = colspecs diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 5b66c537aa..061794bded 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -141,7 +141,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if connection is None: try: self._dbapi_connection = engine.raw_connection() - except dialect.dbapi.Error as err: + except dialect.loaded_dbapi.Error as err: Connection._handle_dbapi_exception_noconnection( err, dialect, engine ) @@ -1809,7 +1809,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if not self._is_disconnect: self._is_disconnect = ( - isinstance(e, self.dialect.dbapi.Error) + isinstance(e, self.dialect.loaded_dbapi.Error) and not self.closed and self.dialect.is_disconnect( e, @@ -1825,7 +1825,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement, parameters, e, - self.dialect.dbapi.Error, + self.dialect.loaded_dbapi.Error, hide_parameters=self.engine.hide_parameters, dialect=self.dialect, ismulti=context.executemany if context is not None else None, @@ -1834,7 +1834,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): try: # non-DBAPI error - if we already got a context, # or there's no string statement, don't wrap it - should_wrap = isinstance(e, self.dialect.dbapi.Error) or ( + should_wrap = isinstance(e, self.dialect.loaded_dbapi.Error) or ( statement is not None and context is None and not is_exit_exception @@ -1845,7 +1845,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement, parameters, cast(Exception, e), - self.dialect.dbapi.Error, + self.dialect.loaded_dbapi.Error, hide_parameters=self.engine.hide_parameters, connection_invalidated=self._is_disconnect, dialect=self.dialect, @@ -1943,17 +1943,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): exc_info = sys.exc_info() is_disconnect = isinstance( - e, dialect.dbapi.Error + e, dialect.loaded_dbapi.Error ) and dialect.is_disconnect(e, None, None) - should_wrap = isinstance(e, dialect.dbapi.Error) + should_wrap = isinstance(e, dialect.loaded_dbapi.Error) if should_wrap: sqlalchemy_exception = exc.DBAPIError.instance( None, None, cast(Exception, e), - dialect.dbapi.Error, + dialect.loaded_dbapi.Error, hide_parameters=engine.hide_parameters, connection_invalidated=is_disconnect, ) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f776e59753..4f151e79cc 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -23,6 +23,7 @@ from typing import List from typing import Optional from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING from typing import Union from .result import MergedResult @@ -57,7 +58,7 @@ if typing.TYPE_CHECKING: from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType - from .result import _ProcessorType + from ..sql.type_api import _ResultProcessorType # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. @@ -77,7 +78,7 @@ MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description _CursorKeyMapRecType = Tuple[ - int, int, List[Any], str, str, Optional["_ProcessorType"], str + int, int, List[Any], str, str, Optional["_ResultProcessorType"], str ] _CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] @@ -164,6 +165,9 @@ class CursorResultMetaData(ResultMetaData): compiled_statement = context.compiled.statement invoked_statement = context.invoked_statement + if TYPE_CHECKING: + assert isinstance(invoked_statement, elements.ClauseElement) + if compiled_statement is invoked_statement: return self diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ba34a0d421..4a833d2e54 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -57,6 +57,8 @@ from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name if typing.TYPE_CHECKING: + from types import ModuleType + from .base import Connection from .base import Engine from .characteristics import ConnectionCharacteristic @@ -67,8 +69,10 @@ if typing.TYPE_CHECKING: from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions + from .interfaces import _IsolationLevel from .interfaces import _MutableCoreSingleExecuteParams - from .result import _ProcessorType + from .interfaces import _ParamStyle + from .interfaces import DBAPIConnection from .row import Row from .url import URL from ..event import _ListenerFnType @@ -76,12 +80,16 @@ if typing.TYPE_CHECKING: from ..pool import PoolProxiedConnection from ..sql import Executable from ..sql.compiler import Compiled + from ..sql.compiler import Linting from ..sql.compiler import ResultColumnsEntry from ..sql.compiler import TypeCompiler from ..sql.dml import DMLState + from ..sql.dml import UpdateBase from ..sql.elements import BindParameter from ..sql.schema import Column from ..sql.schema import ColumnDefault + from ..sql.type_api import _BindProcessorType + from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine # When we're handed literal SQL, ensure it's a SELECT query @@ -102,10 +110,7 @@ class DefaultDialect(Dialect): statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - if typing.TYPE_CHECKING: - type_compiler: TypeCompiler - else: - type_compiler = compiler.GenericTypeCompiler + type_compiler_cls = compiler.GenericTypeCompiler preparer = compiler.IdentifierPreparer supports_alter = True @@ -253,20 +258,19 @@ class DefaultDialect(Dialect): ) def __init__( self, - paramstyle=None, - isolation_level=None, - dbapi=None, - implicit_returning=None, - supports_native_boolean=None, - max_identifier_length=None, - label_length=None, - # int() is because the @deprecated_params decorator cannot accommodate - # the direct reference to the "NO_LINTING" object - compiler_linting=int(compiler.NO_LINTING), - server_side_cursors=False, - **kwargs, + paramstyle: Optional[_ParamStyle] = None, + isolation_level: Optional[_IsolationLevel] = None, + dbapi: Optional[ModuleType] = None, + implicit_returning: Optional[bool] = None, + supports_native_boolean: Optional[bool] = None, + max_identifier_length: Optional[int] = None, + label_length: Optional[int] = None, + # util.deprecated_params decorator cannot render the + # Linting.NO_LINTING constant + compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore + server_side_cursors: bool = False, + **kwargs: Any, ): - if server_side_cursors: if not self.supports_server_side_cursors: raise exc.ArgumentError( @@ -286,7 +290,9 @@ class DefaultDialect(Dialect): self.positional = False self._ischema = None + self.dbapi = dbapi + if paramstyle is not None: self.paramstyle = paramstyle elif self.dbapi is not None: @@ -299,11 +305,17 @@ class DefaultDialect(Dialect): self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level - tt_callable = cast( - Type[compiler.GenericTypeCompiler], - self.type_compiler, - ) - self.type_compiler = tt_callable(self) + legacy_tt_callable = getattr(self, "type_compiler", None) + if legacy_tt_callable is not None: + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + else: + tt_callable = self.type_compiler_cls + + self.type_compiler_instance = self.type_compiler = tt_callable(self) + if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean @@ -315,6 +327,15 @@ class DefaultDialect(Dialect): self.label_length = label_length self.compiler_linting = compiler_linting + @util.memoized_property + def loaded_dbapi(self) -> ModuleType: + if self.dbapi is None: + raise exc.InvalidRequestError( + f"Dialect {self} does not have a Python DBAPI established " + "and cannot be used for actual database interaction" + ) + return self.dbapi + @util.memoized_property def _bind_typing_render_casts(self): return self.bind_typing is interfaces.BindTyping.RENDER_CASTS @@ -495,7 +516,7 @@ class DefaultDialect(Dialect): def connect(self, *cargs, **cparams): # inherits the docstring from interfaces.Dialect.connect - return self.dbapi.connect(*cargs, **cparams) + return self.loaded_dbapi.connect(*cargs, **cparams) def create_connect_args(self, url): # inherits the docstring from interfaces.Dialect.create_connect_args @@ -584,7 +605,7 @@ class DefaultDialect(Dialect): def _dialect_specific_select_one(self): return str(expression.select(1).compile(dialect=self)) - def do_ping(self, dbapi_connection): + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: cursor = None try: cursor = dbapi_connection.cursor() @@ -592,7 +613,7 @@ class DefaultDialect(Dialect): cursor.execute(self._dialect_specific_select_one) finally: cursor.close() - except self.dbapi.Error as err: + except self.loaded_dbapi.Error as err: if self.is_disconnect(err, dbapi_connection, cursor): return False else: @@ -747,7 +768,7 @@ class StrCompileDialect(DefaultDialect): statement_compiler = compiler.StrSQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.StrSQLTypeCompiler # type: ignore + type_compiler_cls = compiler.StrSQLTypeCompiler preparer = compiler.IdentifierPreparer supports_statement_cache = True @@ -906,6 +927,8 @@ class DefaultExecutionContext(ExecutionContext): self.is_text = compiled.isplaintext if self.isinsert or self.isupdate or self.isdelete: + if TYPE_CHECKING: + assert isinstance(compiled.statement, UpdateBase) self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( @@ -943,7 +966,7 @@ class DefaultExecutionContext(ExecutionContext): processors = compiled._bind_processors flattened_processors: Mapping[ - str, _ProcessorType + str, _BindProcessorType[Any] ] = processors # type: ignore[assignment] if compiled.literal_execute_params or compiled.post_compile_params: @@ -1354,7 +1377,7 @@ class DefaultExecutionContext(ExecutionContext): type_ = bindparam.type impl_type = type_.dialect_impl(self.dialect) - dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi) + dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi) result_processor = impl_type.result_processor( self.dialect, dbapi_type ) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 3ca30d1bc9..1b178641e1 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -14,12 +14,14 @@ from types import ModuleType from typing import Any from typing import Awaitable from typing import Callable +from typing import ClassVar from typing import Dict from typing import List from typing import Mapping from typing import MutableMapping from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -60,6 +62,7 @@ if TYPE_CHECKING: from ..sql.schema import DefaultGenerator from ..sql.schema import Sequence as Sequence_SchemaItem from ..sql.sqltypes import Integer + from ..sql.type_api import _TypeMemoDict from ..sql.type_api import TypeEngine ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]] @@ -166,7 +169,7 @@ class DBAPICursor(Protocol): def execute( self, operation: Any, - parameters: Optional[_DBAPISingleExecuteParams], + parameters: Optional[_DBAPISingleExecuteParams] = None, ) -> Any: ... @@ -577,7 +580,7 @@ class Dialect(EventTarget): driver: str """identifying name for the dialect's DBAPI""" - dbapi: ModuleType + dbapi: Optional[ModuleType] """A reference to the DBAPI module object itself. SQLAlchemy dialects import DBAPI modules using the classmethod @@ -600,6 +603,16 @@ class Dialect(EventTarget): """ + @util.non_memoized_property + def loaded_dbapi(self) -> ModuleType: + """same as .dbapi, but is never None; will raise an error if no + DBAPI was set up. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + positional: bool """True if the paramstyle for this Dialect is positional.""" @@ -616,8 +629,28 @@ class Dialect(EventTarget): ddl_compiler: Type[DDLCompiler] """a :class:`.Compiled` class used to compile DDL statements""" - type_compiler: Union[Type[TypeCompiler], TypeCompiler] - """a :class:`.Compiled` class used to compile SQL type objects""" + type_compiler_cls: ClassVar[Type[TypeCompiler]] + """a :class:`.Compiled` class used to compile SQL type objects + + .. versionadded:: 2.0 + + """ + + type_compiler_instance: TypeCompiler + """instance of a :class:`.Compiled` class used to compile SQL type + objects + + .. versionadded:: 2.0 + + """ + + type_compiler: Any + """legacy; this is a TypeCompiler class at the class level, a + TypeCompiler instance at the instance level. + + Refer to type_compiler_instance instead. + + """ preparer: Type[IdentifierPreparer] """a :class:`.IdentifierPreparer` class used to @@ -683,9 +716,17 @@ class Dialect(EventTarget): """ supports_default_values: bool - """Indicates if the construct ``INSERT INTO tablename DEFAULT - VALUES`` is supported - """ + """dialect supports INSERT... DEFAULT VALUES syntax""" + + supports_default_metavalue: bool + """dialect supports INSERT... VALUES (DEFAULT) syntax""" + + supports_empty_insert: bool + """dialect supports INSERT () VALUES ()""" + + supports_multivalues_insert: bool + """Target database supports INSERT...VALUES with multiple value + sets""" preexecute_autoincrement_sequences: bool """True if 'implicit' primary key functions must be executed separately @@ -723,6 +764,12 @@ class Dialect(EventTarget): other backends. """ + default_sequence_base: int + """the default value that will be rendered as the "START WITH" portion of + a CREATE SEQUENCE DDL statement. + + """ + supports_native_enum: bool """Indicates if the dialect supports a native ENUM construct. This will prevent :class:`_types.Enum` from generating a CHECK @@ -735,6 +782,10 @@ class Dialect(EventTarget): constraint when that type is used. """ + supports_native_decimal: bool + """indicates if Decimal objects are handled and returned for precision + numeric types, or if floats are returned""" + construct_arguments: Optional[ List[Tuple[Type[ClauseElement], Mapping[str, Any]]] ] = None @@ -842,6 +893,52 @@ class Dialect(EventTarget): """ + label_length: Optional[int] + """optional user-defined max length for SQL labels""" + + include_set_input_sizes: Optional[Set[Any]] + """set of DBAPI type objects that should be included in + automatic cursor.setinputsizes() calls. + + This is only used if bind_typing is BindTyping.SET_INPUT_SIZES + + """ + + exclude_set_input_sizes: Optional[Set[Any]] + """set of DBAPI type objects that should be excluded in + automatic cursor.setinputsizes() calls. + + This is only used if bind_typing is BindTyping.SET_INPUT_SIZES + + """ + + supports_simple_order_by_label: bool + """target database supports ORDER BY , where + refers to a label in the columns clause of the SELECT""" + + div_is_floordiv: bool + """target database treats the / division operator as "floor division" """ + + tuple_in_values: bool + """target database supports tuple IN, i.e. (x, y) IN ((q, p), (r, z))""" + + _bind_typing_render_casts: bool + + supports_identity_columns: bool + """target database supports IDENTITY""" + + cte_follows_insert: bool + """target database, when given a CTE with an INSERT statement, needs + the CTE to be below the INSERT""" + + insert_executemany_returning: bool + """dialect / driver / database supports some means of providing RETURNING + support when dialect.do_executemany() is used. + + """ + + _type_memos: MutableMapping[TypeEngine[Any], "_TypeMemoDict"] + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: raise NotImplementedError() @@ -1495,7 +1592,7 @@ class Dialect(EventTarget): def is_disconnect( self, e: Exception, - connection: Optional[PoolProxiedConnection], + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], cursor: Optional[DBAPICursor], ) -> bool: """Return True if the given DB-API error indicates an invalid diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py index 7a6a57c03f..5879267529 100644 --- a/lib/sqlalchemy/engine/processors.py +++ b/lib/sqlalchemy/engine/processors.py @@ -20,23 +20,29 @@ from ._py_processors import str_to_datetime_processor_factory # noqa from ..util._has_cy import HAS_CYEXTENSION if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_processors import int_to_boolean # noqa - from ._py_processors import str_to_date # noqa - from ._py_processors import str_to_datetime # noqa - from ._py_processors import str_to_time # noqa - from ._py_processors import to_decimal_processor_factory # noqa - from ._py_processors import to_float # noqa - from ._py_processors import to_str # noqa + from ._py_processors import int_to_boolean as int_to_boolean + from ._py_processors import str_to_date as str_to_date + from ._py_processors import str_to_datetime as str_to_datetime + from ._py_processors import str_to_time as str_to_time + from ._py_processors import ( + to_decimal_processor_factory as to_decimal_processor_factory, + ) + from ._py_processors import to_float as to_float + from ._py_processors import to_str as to_str else: from sqlalchemy.cyextension.processors import ( DecimalResultProcessor, ) # noqa - from sqlalchemy.cyextension.processors import int_to_boolean # noqa - from sqlalchemy.cyextension.processors import str_to_date # noqa - from sqlalchemy.cyextension.processors import str_to_datetime # noqa - from sqlalchemy.cyextension.processors import str_to_time # noqa - from sqlalchemy.cyextension.processors import to_float # noqa - from sqlalchemy.cyextension.processors import to_str # noqa + from sqlalchemy.cyextension.processors import ( + int_to_boolean as int_to_boolean, + ) + from sqlalchemy.cyextension.processors import str_to_date as str_to_date + from sqlalchemy.cyextension.processors import ( + str_to_datetime as str_to_datetime, + ) + from sqlalchemy.cyextension.processors import str_to_time as str_to_time + from sqlalchemy.cyextension.processors import to_float as to_float + from sqlalchemy.cyextension.processors import to_str as to_str def to_decimal_processor_factory(target_class, scale): # Note that the scale argument is not taken into account for integer diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index d428b8a9d4..05b06e8465 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -46,6 +46,7 @@ else: if typing.TYPE_CHECKING: from .row import RowMapping from ..sql.schema import Column + from ..sql.type_api import _ResultProcessorType _KeyType = Union[str, "Column[Any]"] _KeyIndexType = Union[str, "Column[Any]", int] @@ -70,8 +71,7 @@ across all the result types """ -_ProcessorType = Callable[[Any], Any] -_ProcessorsType = Sequence[Optional[_ProcessorType]] +_ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]] _TupleGetterType = Callable[[Sequence[Any]], Tuple[Any, ...]] _UniqueFilterType = Callable[[Any], Any] _UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]] diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index ff63199d40..4ba39b55d6 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -41,6 +41,7 @@ else: if typing.TYPE_CHECKING: from .result import _KeyType from .result import RMKeyView + from ..sql.type_api import _ResultProcessorType class Row(BaseRow, typing.Sequence[Any]): @@ -105,7 +106,7 @@ class Row(BaseRow, typing.Sequence[Any]): ) def _filter_on_values( - self, filters: Optional[Sequence[Optional[Callable[[Any], Any]]]] + self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]] ) -> Row: return Row( self._parent, diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 5a9d28a289..9ab1477100 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -82,7 +82,7 @@ class _ConnDialect: def do_close(self, dbapi_connection: DBAPIConnection) -> None: dbapi_connection.close() - def do_ping(self, dbapi_connection: DBAPIConnection) -> None: + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: raise NotImplementedError( "The ping feature requires that a dialect is " "passed to the connection pool." diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 770fbe40c9..aabd3871e1 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -58,7 +58,7 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T") -def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]: +def all_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]: """Produce an ALL expression. For dialects such as that of PostgreSQL, this operator applies @@ -173,7 +173,7 @@ def and_(*clauses: _ColumnExpression[bool]) -> BooleanClauseList: return BooleanClauseList.and_(*clauses) -def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[_T]: +def any_(expr: _ColumnExpression[_T]) -> CollectionAggregate[bool]: """Produce an ANY expression. For dialects such as that of PostgreSQL, this operator applies diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 29f9028c8b..6a6b389de8 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1132,10 +1132,12 @@ class SchemaEventTarget: """ - def _set_parent(self, parent, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: """Associate with this SchemaEvent's parent object.""" - def _set_parent_with_dispatch(self, parent, **kw): + def _set_parent_with_dispatch( + self, parent: SchemaEventTarget, **kw: Any + ) -> None: self.dispatch.before_parent_attach(self, parent) self._set_parent(parent, **kw) self.dispatch.after_parent_attach(self, parent) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8f878b66c0..f8019b9c64 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -43,6 +43,7 @@ from typing import List from typing import Mapping from typing import MutableMapping from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Set @@ -51,6 +52,7 @@ from typing import Type from typing import TYPE_CHECKING from typing import Union +from sqlalchemy.sql.ddl import DDLElement from . import base from . import coercions from . import crud @@ -61,7 +63,9 @@ from . import schema from . import selectable from . import sqltypes from .base import _from_objects +from .base import Executable from .base import NO_ARG +from .elements import ClauseElement from .elements import quoted_name from .schema import Column from .sqltypes import TupleType @@ -78,6 +82,10 @@ if typing.TYPE_CHECKING: from .base import _AmbiguousTableNameMap from .base import CompileState from .cache_key import CacheKey + from .dml import Insert + from .dml import UpdateBase + from .dml import ValuesBase + from .elements import _truncated_label from .elements import BindParameter from .elements import ColumnClause from .elements import Label @@ -91,12 +99,13 @@ if typing.TYPE_CHECKING: from .selectable import ReturnsRows from .selectable import Select from .selectable import SelectState + from .type_api import _BindProcessorType from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _MutableCoreSingleExecuteParams from ..engine.interfaces import _SchemaTranslateMapType - from ..engine.result import _ProcessorType + from ..engine.interfaces import Dialect _FromHintsType = Dict["FromClause", str] @@ -378,7 +387,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): class ExpandedState(NamedTuple): statement: str additional_parameters: _CoreSingleExecuteParams - processors: Mapping[str, _ProcessorType] + processors: Mapping[str, _BindProcessorType[Any]] positiontup: Optional[Sequence[str]] parameter_expansion: Mapping[str, List[str]] @@ -531,11 +540,11 @@ class Compiled: def __init__( self, - dialect, - statement, - schema_translate_map=None, - render_schema_translate=False, - compile_kwargs=util.immutabledict(), + dialect: Dialect, + statement: Optional[ClauseElement], + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + render_schema_translate: bool = False, + compile_kwargs: Mapping[str, Any] = util.immutabledict(), ): """Construct a new :class:`.Compiled` object. @@ -571,6 +580,8 @@ class Compiled: self.can_execute = statement.supports_execution self._annotations = statement._annotations if self.can_execute: + if TYPE_CHECKING: + assert isinstance(statement, Executable) self.execution_options = statement._execution_options self.string = self.process(self.statement, **compile_kwargs) @@ -636,10 +647,10 @@ class TypeCompiler(util.EnsureKWArg): ensure_kwarg = r"visit_\w+" - def __init__(self, dialect): + def __init__(self, dialect: Dialect): self.dialect = dialect - def process(self, type_, **kw): + def process(self, type_: TypeEngine[Any], **kw: Any) -> str: if ( type_._variant_mapping and self.dialect.name in type_._variant_mapping @@ -647,7 +658,9 @@ class TypeCompiler(util.EnsureKWArg): type_ = type_._variant_mapping[self.dialect.name] return type_._compiler_dispatch(self, **kw) - def visit_unsupported_compilation(self, element, err, **kw): + def visit_unsupported_compilation( + self, element: Any, err: Exception, **kw: Any + ) -> NoReturn: raise exc.UnsupportedCompilationError(self, element) from err @@ -877,13 +890,13 @@ class SQLCompiler(Compiled): def __init__( self, - dialect, - statement, - cache_key=None, - column_keys=None, - for_executemany=False, - linting=NO_LINTING, - **kwargs, + dialect: Dialect, + statement: Optional[ClauseElement], + cache_key: Optional[CacheKey] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + linting: Linting = NO_LINTING, + **kwargs: Any, ): """Construct a new :class:`.SQLCompiler` object. @@ -954,15 +967,21 @@ class SQLCompiler(Compiled): # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length - self.truncated_names = {} + self.truncated_names: Dict[Tuple[str, str], str] = {} + self._truncated_counters: Dict[str, int] = {} Compiled.__init__(self, dialect, statement, **kwargs) if self.isinsert or self.isupdate or self.isdelete: + if TYPE_CHECKING: + assert isinstance(statement, UpdateBase) + if statement._returning: self.returning = statement._returning if self.isinsert or self.isupdate: + if TYPE_CHECKING: + assert isinstance(statement, ValuesBase) if statement._inline: self.inline = True elif self.for_executemany and ( @@ -1082,9 +1101,14 @@ class SQLCompiler(Compiled): @util.memoized_property def _bind_processors( self, - ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]: + ) -> MutableMapping[ + str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] + ]: + + # mypy is not able to see the two value types as the above Union, + # it just sees "object". don't know how to resolve return dict( - (key, value) + (key, value) # type: ignore for key, value in ( ( self.bind_names[bindparam], @@ -1301,12 +1325,14 @@ class SQLCompiler(Compiled): positiontup = None processors = self._bind_processors - single_processors = cast("Mapping[str, _ProcessorType]", processors) + single_processors = cast( + "Mapping[str, _BindProcessorType[Any]]", processors + ) tuple_processors = cast( - "Mapping[str, Sequence[_ProcessorType]]", processors + "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors ) - new_processors: Dict[str, _ProcessorType] = {} + new_processors: Dict[str, _BindProcessorType[Any]] = {} if self.positional and self._numeric_binds: # I'm not familiar with any DBAPI that uses 'numeric'. @@ -1484,6 +1510,10 @@ class SQLCompiler(Compiled): result = util.preloaded.engine_result param_key_getter = self._within_exec_param_key_getter + + if TYPE_CHECKING: + assert isinstance(self.statement, Insert) + table = self.statement.table getters = [ @@ -1530,6 +1560,9 @@ class SQLCompiler(Compiled): else: result = util.preloaded.engine_result + if TYPE_CHECKING: + assert isinstance(self.statement, Insert) + param_key_getter = self._within_exec_param_key_getter table = self.statement.table @@ -1796,7 +1829,9 @@ class SQLCompiler(Compiled): def visit_typeclause(self, typeclause, **kw): kw["type_expression"] = typeclause kw["identifier_preparer"] = self.preparer - return self.dialect.type_compiler.process(typeclause.type, **kw) + return self.dialect.type_compiler_instance.process( + typeclause.type, **kw + ) def post_process_text(self, text): if self.preparer._double_percents: @@ -2855,26 +2890,28 @@ class SQLCompiler(Compiled): return bind_name - def _truncated_identifier(self, ident_class, name): + def _truncated_identifier( + self, ident_class: str, name: _truncated_label + ) -> str: if (ident_class, name) in self.truncated_names: return self.truncated_names[(ident_class, name)] anonname = name.apply_map(self.anon_map) if len(anonname) > self.label_length - 6: - counter = self.truncated_names.get(ident_class, 1) + counter = self._truncated_counters.get(ident_class, 1) truncname = ( anonname[0 : max(self.label_length - 6, 0)] + "_" + hex(counter)[2:] ) - self.truncated_names[ident_class] = counter + 1 + self._truncated_counters[ident_class] = counter + 1 else: truncname = anonname self.truncated_names[(ident_class, name)] = truncname return truncname - def _anonymize(self, name): + def _anonymize(self, name: str) -> str: return name % self.anon_map def bindparam_string( @@ -3221,7 +3258,7 @@ class SQLCompiler(Compiled): % ( self.preparer.quote(col.name), " %s" - % self.dialect.type_compiler.process( + % self.dialect.type_compiler_instance.process( col.type, **kwargs ) if alias._render_derived_w_types @@ -4685,6 +4722,18 @@ class StrSQLCompiler(SQLCompiler): class DDLCompiler(Compiled): + if TYPE_CHECKING: + + def __init__( + self, + dialect: Dialect, + statement: DDLElement, + schema_translate_map: Optional[_SchemaTranslateMapType] = ..., + render_schema_translate: bool = ..., + compile_kwargs: Mapping[str, Any] = ..., + ): + ... + @util.memoized_property def sql_compiler(self): return self.dialect.statement_compiler( @@ -4693,7 +4742,7 @@ class DDLCompiler(Compiled): @util.memoized_property def type_compiler(self): - return self.dialect.type_compiler + return self.dialect.type_compiler_instance def construct_params( self, @@ -5010,7 +5059,7 @@ class DDLCompiler(Compiled): colspec = ( self.preparer.format_column(column) + " " - + self.dialect.type_compiler.process( + + self.dialect.type_compiler_instance.process( column.type, type_expression=column ) ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 96e90b0ea1..1271c5977c 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -433,7 +433,7 @@ class ValuesBase(UpdateBase): _multi_values = () _ordered_values = None _select_names = None - + _inline: bool = False _returning = () def __init__(self, table): @@ -742,7 +742,6 @@ class Insert(ValuesBase): select = None include_insert_from_select_defaults = False - _inline = False is_insert = True @@ -959,7 +958,6 @@ class Update(DMLWhereBase, ValuesBase): is_update = True _preserve_parameter_order = False - _inline = False _traverse_internals = ( [ diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 696d3c6f2e..48c3c3be66 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -258,6 +258,8 @@ class CompilerElement(Visitable): """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + if TYPE_CHECKING: + assert isinstance(self, ClauseElement) return dialect.statement_compiler(dialect, self, **kw) def __str__(self) -> str: @@ -663,6 +665,11 @@ class DQLDMLClauseElement(ClauseElement): if typing.TYPE_CHECKING: + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + ... + def compile( # noqa: A001 self, bind: Optional[Union[Engine, Connection]] = None, @@ -671,9 +678,6 @@ class DQLDMLClauseElement(ClauseElement): ) -> SQLCompiler: ... - def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: - ... - class CompilerColumnElement( roles.DMLColumnRole, @@ -1272,6 +1276,12 @@ class ColumnElement( _alt_names: Sequence[str] = () + @overload + def self_group( + self: ColumnElement[_T], against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: + ... + @overload def self_group( self: ColumnElement[bool], against: Optional[OperatorType] = None @@ -1280,8 +1290,8 @@ class ColumnElement( @overload def self_group( - self: ColumnElement[_T], against: Optional[OperatorType] = None - ) -> ColumnElement[_T]: + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: ... def self_group( @@ -1777,7 +1787,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): value = None if quote is not None: - key = quoted_name(key, quote) + key = quoted_name.construct(key, quote) if unique: self.key = _anonymous_label.safe_construct( @@ -3121,7 +3131,11 @@ class UnaryExpression(ColumnElement[_T]): self.element = element.self_group( against=self.operator or self.modifier ) - self.type: TypeEngine[_T] = type_api.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore + self.wraps_column_expression = wraps_column_expression @classmethod @@ -3224,27 +3238,32 @@ class CollectionAggregate(UnaryExpression[_T]): @classmethod def _create_any( cls, expr: _ColumnExpression[_T] - ) -> CollectionAggregate[_T]: - expr = coercions.expect(roles.ExpressionElementRole, expr) - - expr = expr.self_group() - return CollectionAggregate( + ) -> CollectionAggregate[bool]: + col_expr = coercions.expect( + roles.ExpressionElementRole, expr, + ) + col_expr = col_expr.self_group() + return CollectionAggregate( + col_expr, operator=operators.any_op, - type_=type_api.NULLTYPE, + type_=type_api.BOOLEANTYPE, wraps_column_expression=False, ) @classmethod def _create_all( cls, expr: _ColumnExpression[_T] - ) -> CollectionAggregate[_T]: - expr = coercions.expect(roles.ExpressionElementRole, expr) - expr = expr.self_group() - return CollectionAggregate( + ) -> CollectionAggregate[bool]: + col_expr = coercions.expect( + roles.ExpressionElementRole, expr, + ) + col_expr = col_expr.self_group() + return CollectionAggregate( + col_expr, operator=operators.all_op, - type_=type_api.NULLTYPE, + type_=type_api.BOOLEANTYPE, wraps_column_expression=False, ) @@ -3347,7 +3366,11 @@ class BinaryExpression(ColumnElement[_T]): self.left = left.self_group(against=operator) self.right = right.self_group(against=operator) self.operator = operator - self.type: TypeEngine[_T] = type_api.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore + self.negate = negate self._is_implicitly_boolean = operators.is_boolean(operator) @@ -3509,7 +3532,9 @@ class Grouping(GroupedElement, ColumnElement[_T]): self, element: Union[TextClause, ClauseList, ColumnElement[_T]] ): self.element = element - self.type = getattr(element, "type", type_api.NULLTYPE) + + # nulltype assignment issue + self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore def _with_binary_element_type(self, type_): return self.__class__(self.element._with_binary_element_type(type_)) @@ -3926,10 +3951,13 @@ class Label(roles.LabeledColumnExprRole[_T], ColumnElement[_T]): self.key = self._tq_label = self._tq_key_label = self.name self._element = element - # self._type = type_ - self.type = type_api.to_instance( - type_ or getattr(self._element, "type", None) + + self.type = ( + type_api.to_instance(type_) + if type_ is not None + else self._element.type ) + self._proxies = [element] def __reduce__(self): @@ -4178,7 +4206,11 @@ class ColumnClause( ): self.key = self.name = text self.table = _selectable - self.type: TypeEngine[_T] = type_api.to_instance(type_) + + # if type is None, we get NULLTYPE, which is our _T. But I don't + # know how to get the overloads to express that correctly + self.type = type_api.to_instance(type_) # type: ignore + self.is_literal = is_literal def get_children(self, column_tables=False, **kw): @@ -4465,19 +4497,32 @@ class quoted_name(util.MemoizedSlots, str): quote: Optional[bool] - def __new__(cls, value, quote): + @overload + @classmethod + def construct(cls, value: str, quote: Optional[bool]) -> quoted_name: + ... + + @overload + @classmethod + def construct(cls, value: None, quote: Optional[bool]) -> None: + ... + + @classmethod + def construct( + cls, value: Optional[str], quote: Optional[bool] + ) -> Optional[quoted_name]: if value is None: return None - # experimental - don't bother with quoted_name - # if quote flag is None. doesn't seem to make any dent - # in performance however - # elif not sprcls and quote is None: - # return value - elif isinstance(value, cls) and ( - quote is None or value.quote == quote - ): + else: + return quoted_name(value, quote) + + def __new__(cls, value: str, quote: Optional[bool]) -> quoted_name: + assert ( + value is not None + ), "use quoted_name.construct() for None passthrough" + if isinstance(value, cls) and (quote is None or value.quote == quote): return value - self = super(quoted_name, cls).__new__(cls, value) + self = super().__new__(cls, value) self.quote = quote return self @@ -4579,15 +4624,15 @@ class _truncated_label(quoted_name): __slots__ = () - def __new__(cls, value, quote=None): + def __new__(cls, value: str, quote: Optional[bool] = None) -> Any: quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) return super(_truncated_label, cls).__new__(cls, value, quote) - def __reduce__(self): + def __reduce__(self) -> Any: return self.__class__, (str(self), self.quote) - def apply_map(self, map_): + def apply_map(self, map_: Mapping[str, Any]) -> str: return self diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 78d5241278..5cfb55603f 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -3077,7 +3077,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): elif metadata is not None and schema is None and metadata.schema: self.schema = schema = metadata.schema else: - self.schema = quoted_name(schema, quote_schema) + self.schema = quoted_name.construct(schema, quote_schema) self.metadata = metadata self._key = _get_table_key(name, schema) if metadata: @@ -4258,7 +4258,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): """ self.table = table = None - self.name = quoted_name(name, kw.pop("quote", None)) + self.name = quoted_name.construct(name, kw.pop("quote", None)) self.unique = kw.pop("unique", False) _column_flag = kw.pop("_column_flag", False) if "info" in kw: @@ -4493,7 +4493,7 @@ class MetaData(HasSchemaAttr): """ self.tables = util.FacadeDict() - self.schema = quoted_name(schema, quote_schema) + self.schema = quoted_name.construct(schema, quote_schema) self.naming_convention = ( naming_convention if naming_convention diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 4d0169370c..829c1b72e7 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -17,9 +17,15 @@ import enum import json import pickle from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional from typing import overload from typing import Sequence from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -40,6 +46,7 @@ from .type_api import NativeForEmulated # noqa from .type_api import to_instance from .type_api import TypeDecorator from .type_api import TypeEngine +from .type_api import TypeEngineMixin from .type_api import Variant # noqa from .visitors import InternalTraversal from .. import event @@ -51,11 +58,19 @@ from ..util import langhelpers from ..util import OrderedDict from ..util.typing import Literal +if TYPE_CHECKING: + from .operators import OperatorType + from .type_api import _BindProcessorType + from .type_api import _ComparatorFactory + from .type_api import _ResultProcessorType + from ..engine.interfaces import Dialect _T = TypeVar("_T", bound="Any") +_CT = TypeVar("_CT", bound=Any) +_TE = TypeVar("_TE", bound="TypeEngine[Any]") -class _LookupExpressionAdapter: +class HasExpressionLookup(TypeEngineMixin): """Mixin expression adaptations based on lookup tables. @@ -68,11 +83,18 @@ class _LookupExpressionAdapter: def _expression_adaptations(self): raise NotImplementedError() - class Comparator(TypeEngine.Comparator[_T]): - _blank_dict = util.immutabledict() + class Comparator(TypeEngine.Comparator[_CT]): + + _blank_dict = util.EMPTY_DICT - def _adapt_expression(self, op, other_comparator): + def _adapt_expression( + self, + op: OperatorType, + other_comparator: TypeEngine.Comparator[Any], + ) -> Tuple[OperatorType, TypeEngine[Any]]: othertype = other_comparator.type._type_affinity + if TYPE_CHECKING: + assert isinstance(self.type, HasExpressionLookup) lookup = self.type._expression_adaptations.get( op, self._blank_dict ).get(othertype, self.type) @@ -83,16 +105,20 @@ class _LookupExpressionAdapter: else: return (op, to_instance(lookup)) - comparator_factory = Comparator + comparator_factory: _ComparatorFactory[Any] = Comparator -class Concatenable: +class Concatenable(TypeEngineMixin): """A mixin that marks a type as supporting 'concatenation', typically strings.""" class Comparator(TypeEngine.Comparator[_T]): - def _adapt_expression(self, op, other_comparator): + def _adapt_expression( + self, + op: OperatorType, + other_comparator: TypeEngine.Comparator[Any], + ) -> Tuple[OperatorType, TypeEngine[Any]]: if op is operators.add and isinstance( other_comparator, (Concatenable.Comparator, NullType.Comparator), @@ -103,10 +129,10 @@ class Concatenable: op, other_comparator ) - comparator_factory = Comparator + comparator_factory: _ComparatorFactory[Any] = Comparator -class Indexable: +class Indexable(TypeEngineMixin): """A mixin that marks a type as supporting indexing operations, such as array or JSON structures. @@ -151,7 +177,7 @@ class String(Concatenable, TypeEngine[str]): # note pylance appears to require the "self" type in a constructor # for the _T type to be correctly recognized when we send the # class as the argument, e.g. `column("somecol", String)` - self: "String", + self, length=None, collation=None, ): @@ -313,7 +339,7 @@ class UnicodeText(Text): super(UnicodeText, self).__init__(length=length, **kwargs) -class Integer(_LookupExpressionAdapter, TypeEngine[int]): +class Integer(HasExpressionLookup, TypeEngine[int]): """A type for ``int`` integers.""" @@ -378,7 +404,7 @@ class BigInteger(Integer): _N = TypeVar("_N", bound=Union[decimal.Decimal, float]) -class Numeric(_LookupExpressionAdapter, TypeEngine[_N]): +class Numeric(HasExpressionLookup, TypeEngine[_N]): """A type for fixed precision numbers, such as ``NUMERIC`` or ``DECIMAL``. @@ -423,7 +449,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine[_N]): _default_decimal_return_scale = 10 def __init__( - self: "Numeric", + self, precision=None, scale=None, decimal_return_scale=None, @@ -573,34 +599,26 @@ class Float(Numeric[_N]): @overload def __init__( self: Float[float], - precision=..., - decimal_return_scale=..., + precision: Optional[int] = ..., + asdecimal: Literal[False] = ..., + decimal_return_scale: Optional[int] = ..., ): ... @overload def __init__( self: Float[decimal.Decimal], - precision=..., + precision: Optional[int] = ..., asdecimal: Literal[True] = ..., - decimal_return_scale=..., - ): - ... - - @overload - def __init__( - self: Float[float], - precision=..., - asdecimal: Literal[False] = ..., - decimal_return_scale=..., + decimal_return_scale: Optional[int] = ..., ): ... def __init__( self: Float[_N], - precision=None, - asdecimal=False, - decimal_return_scale=None, + precision: Optional[int] = None, + asdecimal: bool = False, + decimal_return_scale: Optional[int] = None, ): r""" Construct a Float. @@ -662,7 +680,7 @@ class Float(Numeric[_N]): return None -class Double(Float): +class Double(Float[_N]): """A type for double ``FLOAT`` floating point types. Typically generates a ``DOUBLE`` or ``DOUBLE_PRECISION`` in DDL, @@ -676,7 +694,7 @@ class Double(Float): __visit_name__ = "double" -class DateTime(_LookupExpressionAdapter, TypeEngine[dt.datetime]): +class DateTime(HasExpressionLookup, TypeEngine[dt.datetime]): """A type for ``datetime.datetime()`` objects. @@ -738,7 +756,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine[dt.datetime]): } -class Date(_LookupExpressionAdapter, TypeEngine[dt.date]): +class Date(HasExpressionLookup, TypeEngine[dt.date]): """A type for ``datetime.date()`` objects.""" @@ -776,7 +794,7 @@ class Date(_LookupExpressionAdapter, TypeEngine[dt.date]): } -class Time(_LookupExpressionAdapter, TypeEngine[dt.time]): +class Time(HasExpressionLookup, TypeEngine[dt.time]): """A type for ``datetime.time()`` objects.""" @@ -895,9 +913,10 @@ class LargeBinary(_Binary): _Binary.__init__(self, length=length) -class SchemaType(SchemaEventTarget): +class SchemaType(SchemaEventTarget, TypeEngineMixin): - """Mark a type as possibly requiring schema-level DDL for usage. + """Add capabilities to a type which allow for schema-level DDL to be + associated with a type. Supports types that must be explicitly created/dropped (i.e. PG ENUM type) as well as types that are complimented by table or schema level @@ -920,6 +939,8 @@ class SchemaType(SchemaEventTarget): _use_schema_map = True + name: Optional[str] + def __init__( self, name=None, @@ -1021,33 +1042,37 @@ class SchemaType(SchemaEventTarget): ) def copy(self, **kw): - return self.adapt(self.__class__, _create_events=True) - - def adapt(self, impltype, **kw): - schema = kw.pop("schema", self.schema) - metadata = kw.pop("metadata", self.metadata) - _create_events = kw.pop("_create_events", False) - return impltype( - name=self.name, - schema=schema, - inherit_schema=self.inherit_schema, - metadata=metadata, - _create_events=_create_events, - **kw, + return self.adapt( + cast("Type[TypeEngine[Any]]", self.__class__), + _create_events=True, ) + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... + + @overload + def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: + ... + + def adapt( + self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any + ) -> TypeEngine[Any]: + kw.setdefault("_create_events", False) + return super().adapt(cls, **kw) + def create(self, bind, checkfirst=False): """Issue CREATE DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) - if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.create(bind, checkfirst=checkfirst) def drop(self, bind, checkfirst=False): """Issue DROP DDL for this type, if applicable.""" t = self.dialect_impl(bind.dialect) - if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t.drop(bind, checkfirst=checkfirst) def _on_table_create(self, target, bind, **kw): @@ -1055,7 +1080,7 @@ class SchemaType(SchemaEventTarget): return t = self.dialect_impl(bind.dialect) - if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_create(target, bind, **kw) def _on_table_drop(self, target, bind, **kw): @@ -1063,7 +1088,7 @@ class SchemaType(SchemaEventTarget): return t = self.dialect_impl(bind.dialect) - if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_table_drop(target, bind, **kw) def _on_metadata_create(self, target, bind, **kw): @@ -1071,7 +1096,7 @@ class SchemaType(SchemaEventTarget): return t = self.dialect_impl(bind.dialect) - if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_create(target, bind, **kw) def _on_metadata_drop(self, target, bind, **kw): @@ -1079,7 +1104,7 @@ class SchemaType(SchemaEventTarget): return t = self.dialect_impl(bind.dialect) - if t.__class__ is not self.__class__ and isinstance(t, SchemaType): + if isinstance(t, SchemaType) and t.__class__ is not self.__class__: t._on_metadata_drop(target, bind, **kw) def _is_impl_for_variant(self, dialect, kw): @@ -1112,7 +1137,7 @@ class SchemaType(SchemaEventTarget): return _we_are_the_impl(variant_mapping["_default"]) -class Enum(Emulated, String, TypeEngine[Union[str, enum.Enum]], SchemaType): +class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): """Generic Enum Type. The :class:`.Enum` type provides a set of possible string values @@ -1464,8 +1489,14 @@ class Enum(Emulated, String, TypeEngine[Union[str, enum.Enum]], SchemaType): ) ) from err - class Comparator(String.Comparator[_T]): - def _adapt_expression(self, op, other_comparator): + class Comparator(String.Comparator[str]): + type: String + + def _adapt_expression( + self, + op: OperatorType, + other_comparator: TypeEngine.Comparator[Any], + ) -> Tuple[OperatorType, TypeEngine[Any]]: op, typ = super(Enum.Comparator, self)._adapt_expression( op, other_comparator ) @@ -1663,15 +1694,16 @@ class PickleType(TypeDecorator[object]): return PickleType, (self.protocol, None, self.comparator) def bind_processor(self, dialect): - impl_processor = self.impl.bind_processor(dialect) + impl_processor = self.impl_instance.bind_processor(dialect) dumps = self.pickler.dumps protocol = self.protocol if impl_processor: + fixed_impl_processor = impl_processor def process(value): if value is not None: value = dumps(value, protocol) - return impl_processor(value) + return fixed_impl_processor(value) else: @@ -1683,12 +1715,13 @@ class PickleType(TypeDecorator[object]): return process def result_processor(self, dialect, coltype): - impl_processor = self.impl.result_processor(dialect, coltype) + impl_processor = self.impl_instance.result_processor(dialect, coltype) loads = self.pickler.loads if impl_processor: + fixed_impl_processor = impl_processor def process(value): - value = impl_processor(value) + value = fixed_impl_processor(value) if value is None: return None return loads(value) @@ -1709,7 +1742,7 @@ class PickleType(TypeDecorator[object]): return x == y -class Boolean(Emulated, TypeEngine[bool], SchemaType): +class Boolean(SchemaType, Emulated, TypeEngine[bool]): """A bool datatype. @@ -1733,7 +1766,7 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType): native = True def __init__( - self: "Boolean", + self, create_constraint=False, name=None, _create_events=True, @@ -1818,6 +1851,9 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType): def bind_processor(self, dialect): _strict_as_bool = self._strict_as_bool + + _coerce: Union[Type[bool], Type[int]] + if dialect.supports_native_boolean: _coerce = bool else: @@ -1838,7 +1874,7 @@ class Boolean(Emulated, TypeEngine[bool], SchemaType): return processors.int_to_boolean -class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[dt.timedelta]): +class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]): @util.memoized_property def _expression_adaptations(self): # Based on https://www.postgresql.org/docs/current/\ @@ -1856,16 +1892,12 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine[dt.timedelta]): operators.truediv: {Numeric: self.__class__}, } - @property - def _type_affinity(self): + @util.non_memoized_property + def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]: return Interval - def coerce_compared_value(self, op, value): - """See :meth:`.TypeEngine.coerce_compared_value` for a description.""" - return self.impl.coerce_compared_value(op, value) - -class Interval(Emulated, _AbstractInterval, TypeDecorator): +class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): """A type for ``datetime.timedelta()`` objects. @@ -1909,6 +1941,14 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator): self.second_precision = second_precision self.day_precision = day_precision + class Comparator( + TypeDecorator.Comparator[_CT], + _AbstractInterval.Comparator[_CT], + ): + pass + + comparator_factory = Comparator + @property def python_type(self): return dt.timedelta @@ -1916,42 +1956,63 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator): def adapt_to_emulated(self, impltype, **kw): return _AbstractInterval.adapt(self, impltype, **kw) - def bind_processor(self, dialect): - impl_processor = self.impl.bind_processor(dialect) + def coerce_compared_value(self, op, value): + return self.impl_instance.coerce_compared_value(op, value) + + def bind_processor( + self, dialect: Dialect + ) -> _BindProcessorType[dt.timedelta]: + if TYPE_CHECKING: + assert isinstance(self.impl_instance, DateTime) + impl_processor = self.impl_instance.bind_processor(dialect) epoch = self.epoch if impl_processor: + fixed_impl_processor = impl_processor - def process(value): + def process( + value: Optional[dt.timedelta], + ) -> Any: if value is not None: - value = epoch + value - return impl_processor(value) + dt_value = epoch + value + else: + dt_value = None + return fixed_impl_processor(dt_value) else: - def process(value): + def process( + value: Optional[dt.timedelta], + ) -> Any: if value is not None: - value = epoch + value - return value + dt_value = epoch + value + else: + dt_value = None + return dt_value return process - def result_processor(self, dialect, coltype): - impl_processor = self.impl.result_processor(dialect, coltype) + def result_processor( + self, dialect: Dialect, coltype: Any + ) -> _ResultProcessorType[dt.timedelta]: + if TYPE_CHECKING: + assert isinstance(self.impl_instance, DateTime) + impl_processor = self.impl_instance.result_processor(dialect, coltype) epoch = self.epoch if impl_processor: + fixed_impl_processor = impl_processor - def process(value): - value = impl_processor(value) - if value is None: + def process(value: Any) -> Optional[dt.timedelta]: + dt_value = fixed_impl_processor(value) + if dt_value is None: return None - return value - epoch + return dt_value - epoch else: - def process(value): + def process(value: Any) -> Optional[dt.timedelta]: if value is None: return None - return value - epoch + return value - epoch # type: ignore return process @@ -2233,7 +2294,7 @@ class JSON(Indexable, TypeEngine[Any]): """ self.none_as_null = none_as_null - class JSONElementType(TypeEngine): + class JSONElementType(TypeEngine[Any]): """Common function for index / path elements in a JSON expression.""" _integer = Integer() @@ -2457,7 +2518,7 @@ class JSON(Indexable, TypeEngine[Any]): def python_type(self): return dict - @property + @property # type: ignore # mypy property bug def should_evaluate_none(self): """Alias of :attr:`_types.JSON.none_as_null`""" return not self.none_as_null @@ -2632,23 +2693,26 @@ class ARRAY( """ def _setup_getitem(self, index): + + arr_type = cast(ARRAY, self.type) + if isinstance(index, slice): - return_type = self.type - if self.type.zero_indexes: + return_type = arr_type + if arr_type.zero_indexes: index = slice(index.start + 1, index.stop + 1, index.step) slice_ = Slice( index.start, index.stop, index.step, _name=self.expr.key ) return operators.getitem, slice_, return_type else: - if self.type.zero_indexes: + if arr_type.zero_indexes: index += 1 - if self.type.dimensions is None or self.type.dimensions == 1: - return_type = self.type.item_type + if arr_type.dimensions is None or arr_type.dimensions == 1: + return_type = arr_type.item_type else: - adapt_kw = {"dimensions": self.type.dimensions - 1} - return_type = self.type.adapt( - self.type.__class__, **adapt_kw + adapt_kw = {"dimensions": arr_type.dimensions - 1} + return_type = arr_type.adapt( + arr_type.__class__, **adapt_kw ) return operators.getitem, index, return_type @@ -2853,7 +2917,7 @@ class TupleType(TypeEngine[Tuple[Any, ...]]): ) -class REAL(Float): +class REAL(Float[_N]): """The SQL REAL type. @@ -2866,7 +2930,7 @@ class REAL(Float): __visit_name__ = "REAL" -class FLOAT(Float): +class FLOAT(Float[_N]): """The SQL FLOAT type. @@ -2879,7 +2943,7 @@ class FLOAT(Float): __visit_name__ = "FLOAT" -class DOUBLE(Double): +class DOUBLE(Double[_N]): """The SQL DOUBLE type. .. versionadded:: 2.0 @@ -2893,7 +2957,7 @@ class DOUBLE(Double): __visit_name__ = "DOUBLE" -class DOUBLE_PRECISION(Double): +class DOUBLE_PRECISION(Double[_N]): """The SQL DOUBLE PRECISION type. .. versionadded:: 2.0 @@ -2907,7 +2971,7 @@ class DOUBLE_PRECISION(Double): __visit_name__ = "DOUBLE_PRECISION" -class NUMERIC(Numeric): +class NUMERIC(Numeric[_N]): """The SQL NUMERIC type. @@ -2920,7 +2984,7 @@ class NUMERIC(Numeric): __visit_name__ = "NUMERIC" -class DECIMAL(Numeric): +class DECIMAL(Numeric[_N]): """The SQL DECIMAL type. @@ -3099,7 +3163,7 @@ class BOOLEAN(Boolean): __visit_name__ = "BOOLEAN" -class NullType(TypeEngine): +class NullType(TypeEngine[None]): """An unknown type. @@ -3139,7 +3203,11 @@ class NullType(TypeEngine): return process class Comparator(TypeEngine.Comparator[_T]): - def _adapt_expression(self, op, other_comparator): + def _adapt_expression( + self, + op: OperatorType, + other_comparator: TypeEngine.Comparator[Any], + ) -> Tuple[OperatorType, TypeEngine[Any]]: if isinstance( other_comparator, NullType.Comparator ) or not operators.is_commutative(op): @@ -3150,7 +3218,7 @@ class NullType(TypeEngine): comparator_factory = Comparator -class TableValueType(HasCacheKey, TypeEngine): +class TableValueType(HasCacheKey, TypeEngine[Any]): """Refers to a table value type.""" _is_table_value = True @@ -3195,7 +3263,7 @@ _TIME = Time() _STRING = String() _UNICODE = Unicode() -_type_map = { +_type_map: Dict[Type[Any], TypeEngine[Any]] = { int: Integer(), float: Float(), bool: BOOLEANTYPE, @@ -3204,7 +3272,7 @@ _type_map = { dt.datetime: _DATETIME, dt.time: _TIME, dt.timedelta: Interval(), - util.NoneType: NULLTYPE, + type(None): NULLTYPE, bytes: LargeBinary(), str: _STRING, } @@ -3213,7 +3281,7 @@ _type_map = { _type_map_get = _type_map.get -def _resolve_value_to_type(value): +def _resolve_value_to_type(value: Any) -> TypeEngine[Any]: _result_type = _type_map_get(type(value), False) if _result_type is False: # use inspect() to detect SQLAlchemy built-in @@ -3231,7 +3299,9 @@ def _resolve_value_to_type(value): ) return NULLTYPE else: - return _result_type._resolve_for_literal(value) + return _result_type._resolve_for_literal( # type: ignore [union-attr] + value + ) # back-assign to type_api @@ -3240,7 +3310,6 @@ type_api.STRINGTYPE = STRINGTYPE type_api.INTEGERTYPE = INTEGERTYPE type_api.NULLTYPE = NULLTYPE type_api.MATCHTYPE = MATCHTYPE -type_api.INDEXABLE = Indexable +type_api.INDEXABLE = INDEXABLE = Indexable type_api.TABLEVALUE = TABLEVALUE type_api._resolve_value_to_type = _resolve_value_to_type -TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 55997556af..5a0aba694f 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -11,41 +11,92 @@ from __future__ import annotations +from types import ModuleType import typing from typing import Any from typing import Callable +from typing import cast +from typing import Dict from typing import Generic +from typing import Mapping from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from .base import SchemaEventTarget +from .cache_key import CacheConst from .cache_key import NO_CACHE from .operators import ColumnOperators from .visitors import Visitable from .. import exc from .. import util +from ..util.typing import Protocol +from ..util.typing import TypedDict +from ..util.typing import TypeGuard # these are back-assigned by sqltypes. if typing.TYPE_CHECKING: + from .elements import BindParameter from .elements import ColumnElement from .operators import OperatorType + from .schema import Column from .sqltypes import _resolve_value_to_type as _resolve_value_to_type from .sqltypes import BOOLEANTYPE as BOOLEANTYPE - from .sqltypes import Indexable as INDEXABLE + from .sqltypes import INDEXABLE as INDEXABLE from .sqltypes import INTEGERTYPE as INTEGERTYPE from .sqltypes import MATCHTYPE as MATCHTYPE from .sqltypes import NULLTYPE as NULLTYPE + from .sqltypes import NullType + from .sqltypes import STRINGTYPE as STRINGTYPE + from .sqltypes import TABLEVALUE as TABLEVALUE + from ..engine.interfaces import Dialect _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) +_O = TypeVar("_O", bound=object) _TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) # replace with pep-673 when applicable -SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine") +SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]") + + +class _LiteralProcessorType(Protocol[_T_co]): + def __call__(self, value: Any) -> str: + ... + + +class _BindProcessorType(Protocol[_T_con]): + def __call__(self, value: Optional[_T_con]) -> Any: + ... + + +class _ResultProcessorType(Protocol[_T_co]): + def __call__(self, value: Any) -> Optional[_T_co]: + ... + + +class _BaseTypeMemoDict(TypedDict): + impl: TypeEngine[Any] + result: Dict[Any, Optional[_ResultProcessorType[Any]]] + + +class _TypeMemoDict(_BaseTypeMemoDict, total=False): + literal: Optional[_LiteralProcessorType[Any]] + bind: Optional[_BindProcessorType[Any]] + custom: Dict[Any, object] + + +class _ComparatorFactory(Protocol[_T]): + def __call__(self, expr: ColumnElement[_T]) -> TypeEngine.Comparator[_T]: + ... class TypeEngine(Visitable, Generic[_T]): @@ -70,8 +121,6 @@ class TypeEngine(Visitable, Generic[_T]): _is_array = False _is_type_decorator = False - _block_from_type_affinity = False - render_bind_cast = False """Render bind casts for :attr:`.BindTyping.RENDER_CASTS` mode. @@ -99,38 +148,41 @@ class TypeEngine(Visitable, Generic[_T]): __slots__ = "expr", "type" - default_comparator = None + expr: ColumnElement[_CT] + type: TypeEngine[_CT] - def __clause_element__(self): + def __clause_element__(self) -> ColumnElement[_CT]: return self.expr - def __init__(self, expr: "ColumnElement[_CT]"): + def __init__(self, expr: ColumnElement[_CT]): self.expr = expr - self.type: TypeEngine[_CT] = expr.type + self.type = expr.type @util.preload_module("sqlalchemy.sql.default_comparator") def operate( - self, op: "OperatorType", *other, **kwargs - ) -> "ColumnElement": + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[_CT]: default_comparator = util.preloaded.sql_default_comparator op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] if kwargs: addtl_kw = addtl_kw.union(kwargs) - return op_fn(self.expr, op, *other, **addtl_kw) + return op_fn(self.expr, op, *other, **addtl_kw) # type: ignore @util.preload_module("sqlalchemy.sql.default_comparator") def reverse_operate( - self, op: "OperatorType", other, **kwargs - ) -> "ColumnElement": + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[_CT]: default_comparator = util.preloaded.sql_default_comparator op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] if kwargs: addtl_kw = addtl_kw.union(kwargs) - return op_fn(self.expr, op, other, reverse=True, **addtl_kw) + return op_fn(self.expr, op, other, reverse=True, **addtl_kw) # type: ignore # noqa E501 def _adapt_expression( - self, op: "OperatorType", other_comparator - ) -> Tuple["OperatorType", "TypeEngine[_CT]"]: + self, + op: OperatorType, + other_comparator: TypeEngine.Comparator[Any], + ) -> Tuple[OperatorType, TypeEngine[Any]]: """evaluate the return type of , and apply any adaptations to the given operator. @@ -159,7 +211,7 @@ class TypeEngine(Visitable, Generic[_T]): return op, self.type - def __reduce__(self): + def __reduce__(self) -> Any: return _reconstitute_comparator, (self.expr,) hashable = True @@ -169,7 +221,7 @@ class TypeEngine(Visitable, Generic[_T]): """ - comparator_factory = Comparator + comparator_factory: _ComparatorFactory[Any] = Comparator """A :class:`.TypeEngine.Comparator` class which will apply to operations performed by owning :class:`_expression.ColumnElement` objects. @@ -193,7 +245,7 @@ class TypeEngine(Visitable, Generic[_T]): """ - sort_key_function = None + sort_key_function: Optional[Callable[[Any], Any]] = None """A sorting function that can be passed as the key to sorted. The default value of ``None`` indicates that the values stored by @@ -203,7 +255,7 @@ class TypeEngine(Visitable, Generic[_T]): """ - should_evaluate_none = False + should_evaluate_none: bool = False """If True, the Python constant ``None`` is considered to be handled explicitly by this type. @@ -226,9 +278,11 @@ class TypeEngine(Visitable, Generic[_T]): """ - _variant_mapping = util.EMPTY_DICT + _variant_mapping: util.immutabledict[ + str, TypeEngine[Any] + ] = util.EMPTY_DICT - def evaluates_none(self): + def evaluates_none(self: SelfTypeEngine) -> SelfTypeEngine: """Return a copy of this type which has the :attr:`.should_evaluate_none` flag set to True. @@ -280,10 +334,12 @@ class TypeEngine(Visitable, Generic[_T]): typ.should_evaluate_none = True return typ - def copy(self, **kw): + def copy(self: SelfTypeEngine, **kw: Any) -> SelfTypeEngine: return self.adapt(self.__class__) - def compare_against_backend(self, dialect, conn_type): + def compare_against_backend( + self, dialect: Dialect, conn_type: TypeEngine[Any] + ) -> Optional[bool]: """Compare this type against the given backend type. This function is currently not implemented for SQLAlchemy @@ -310,10 +366,12 @@ class TypeEngine(Visitable, Generic[_T]): """ return None - def copy_value(self, value): + def copy_value(self, value: Any) -> Any: return value - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: """Return a conversion function for processing literal values that are to be rendered directly without using binds. @@ -348,7 +406,9 @@ class TypeEngine(Visitable, Generic[_T]): """ return None - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[_T]]: """Return a conversion function for processing bind values. Returns a callable which will receive a bind parameter value @@ -382,7 +442,9 @@ class TypeEngine(Visitable, Generic[_T]): """ return None - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> Optional[_ResultProcessorType[_T]]: """Return a conversion function for processing result row values. Returns a callable which will receive a result row column @@ -417,7 +479,9 @@ class TypeEngine(Visitable, Generic[_T]): """ return None - def column_expression(self, colexpr): + def column_expression( + self, colexpr: ColumnElement[_T] + ) -> Optional[ColumnElement[_T]]: """Given a SELECT column expression, return a wrapping SQL expression. This is typically a SQL function that wraps a column expression @@ -461,7 +525,7 @@ class TypeEngine(Visitable, Generic[_T]): return None @util.memoized_property - def _has_column_expression(self): + def _has_column_expression(self) -> bool: """memoized boolean, check if column_expression is implemented. Allows the method to be skipped for the vast majority of expression @@ -474,7 +538,9 @@ class TypeEngine(Visitable, Generic[_T]): is not TypeEngine.column_expression.__code__ ) - def bind_expression(self, bindvalue): + def bind_expression( + self, bindvalue: BindParameter[_T] + ) -> Optional[ColumnElement[_T]]: """Given a bind value (i.e. a :class:`.BindParameter` instance), return a SQL expression in its place. @@ -521,7 +587,7 @@ class TypeEngine(Visitable, Generic[_T]): return None @util.memoized_property - def _has_bind_expression(self): + def _has_bind_expression(self) -> bool: """memoized boolean, check if bind_expression is implemented. Allows the method to be skipped for the vast majority of expression @@ -535,12 +601,12 @@ class TypeEngine(Visitable, Generic[_T]): def _to_instance(cls_or_self: Union[Type[_TE], _TE]) -> _TE: return to_instance(cls_or_self) - def compare_values(self, x, y): + def compare_values(self, x: Any, y: Any) -> bool: """Compare two values for equality.""" - return x == y + return x == y # type: ignore[no-any-return] - def get_dbapi_type(self, dbapi): + def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]: """Return the corresponding type object from the underlying DB-API, if any. @@ -550,7 +616,7 @@ class TypeEngine(Visitable, Generic[_T]): return None @property - def python_type(self): + def python_type(self) -> Type[Any]: """Return the Python type object expected to be returned by instances of this type, if known. @@ -569,7 +635,7 @@ class TypeEngine(Visitable, Generic[_T]): raise NotImplementedError() def with_variant( - self: SelfTypeEngine, type_: "TypeEngine", *dialect_names: str + self: SelfTypeEngine, type_: TypeEngine[Any], *dialect_names: str ) -> SelfTypeEngine: r"""Produce a copy of this type object that will utilize the given type when applied to the dialect of the given name. @@ -626,7 +692,9 @@ class TypeEngine(Visitable, Generic[_T]): ) return new_type - def _resolve_for_literal(self, value): + def _resolve_for_literal( + self: SelfTypeEngine, value: Any + ) -> SelfTypeEngine: """adjust this type given a literal Python value that will be stored in a bound parameter. @@ -638,28 +706,28 @@ class TypeEngine(Visitable, Generic[_T]): return self @util.memoized_property - def _type_affinity(self): + def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]: """Return a rudimental 'affinity' value expressing the general class of type.""" typ = None for t in self.__class__.__mro__: - if t in (TypeEngine, UserDefinedType): + if t is TypeEngine or TypeEngineMixin in t.__bases__: return typ - elif issubclass( - t, (TypeEngine, UserDefinedType) - ) and not t.__dict__.get("_block_from_type_affinity", False): + elif issubclass(t, TypeEngine): typ = t else: return self.__class__ @util.memoized_property - def _generic_type_affinity(self): + def _generic_type_affinity( + self, + ) -> Type[TypeEngine[_T]]: best_camelcase = None best_uppercase = None - if not isinstance(self, (TypeEngine, UserDefinedType)): - return self.__class__ + if not isinstance(self, TypeEngine): + return self.__class__ # type: ignore # mypy bug? for t in self.__class__.__mro__: if ( @@ -669,7 +737,8 @@ class TypeEngine(Visitable, Generic[_T]): "sqlalchemy.sql.type_api", ) and issubclass(t, TypeEngine) - and t is not TypeEngine + and TypeEngineMixin not in t.__bases__ + and t not in (TypeEngine, TypeEngineMixin) and t.__name__[0] != "_" ): if t.__name__.isupper() and not best_uppercase: @@ -677,9 +746,13 @@ class TypeEngine(Visitable, Generic[_T]): elif not t.__name__.isupper() and not best_camelcase: best_camelcase = t - return best_camelcase or best_uppercase or NULLTYPE.__class__ + return ( + best_camelcase + or best_uppercase + or cast("Type[TypeEngine[_T]]", NULLTYPE.__class__) + ) - def as_generic(self, allow_nulltype=False): + def as_generic(self, allow_nulltype: bool = False) -> TypeEngine[_T]: """ Return an instance of the generic type corresponding to this type using heuristic rule. The method may be overridden if this @@ -719,18 +792,20 @@ class TypeEngine(Visitable, Generic[_T]): return util.constructor_copy(self, self._generic_type_affinity) - def dialect_impl(self, dialect): + def dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: """Return a dialect-specific implementation for this :class:`.TypeEngine`. """ try: - return dialect._type_memos[self]["impl"] + tm = dialect._type_memos[self] except KeyError: pass + else: + return tm["impl"] return self._dialect_info(dialect)["impl"] - def _unwrapped_dialect_impl(self, dialect): + def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: """Return the 'unwrapped' dialect impl for this type. For a type that applies wrapping logic (e.g. TypeDecorator), give @@ -744,60 +819,80 @@ class TypeEngine(Visitable, Generic[_T]): """ return self.dialect_impl(dialect) - def _cached_literal_processor(self, dialect): + def _cached_literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: """Return a dialect-specific literal processor for this type.""" + try: return dialect._type_memos[self]["literal"] except KeyError: pass + # avoid KeyError context coming into literal_processor() function # raises d = self._dialect_info(dialect) d["literal"] = lp = d["impl"].literal_processor(dialect) return lp - def _cached_bind_processor(self, dialect): + def _cached_bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[_T]]: """Return a dialect-specific bind processor for this type.""" try: return dialect._type_memos[self]["bind"] except KeyError: pass + # avoid KeyError context coming into bind_processor() function # raises d = self._dialect_info(dialect) d["bind"] = bp = d["impl"].bind_processor(dialect) return bp - def _cached_result_processor(self, dialect, coltype): + def _cached_result_processor( + self, dialect: Dialect, coltype: Any + ) -> Optional[_ResultProcessorType[_T]]: """Return a dialect-specific result processor for this type.""" try: - return dialect._type_memos[self][coltype] + return dialect._type_memos[self]["result"][coltype] except KeyError: pass + # avoid KeyError context coming into result_processor() function # raises d = self._dialect_info(dialect) # key assumption: DBAPI type codes are # constants. Else this dictionary would # grow unbounded. - d[coltype] = rp = d["impl"].result_processor(dialect, coltype) + rp = d["impl"].result_processor(dialect, coltype) + d["result"][coltype] = rp return rp - def _cached_custom_processor(self, dialect, key, fn): + def _cached_custom_processor( + self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T]], _O] + ) -> _O: + """return a dialect-specific processing object for + custom purposes. + + The cx_Oracle dialect uses this at the moment. + + """ try: - return dialect._type_memos[self][key] + return cast(_O, dialect._type_memos[self]["custom"][key]) except KeyError: pass # avoid KeyError context coming into fn() function # raises d = self._dialect_info(dialect) impl = d["impl"] - d[key] = result = fn(impl) + custom_dict = d.setdefault("custom", {}) + custom_dict[key] = result = fn(impl) return result - def _dialect_info(self, dialect): + def _dialect_info(self, dialect: Dialect) -> _TypeMemoDict: """Return a dialect-specific registry which caches a dialect-specific implementation, bind processing function, and one or more result processing functions.""" @@ -810,10 +905,11 @@ class TypeEngine(Visitable, Generic[_T]): impl = self.adapt(type(self)) # this can't be self, else we create a cycle assert impl is not self - dialect._type_memos[self] = d = {"impl": impl} + d: _TypeMemoDict = {"impl": impl, "result": {}} + dialect._type_memos[self] = d return d - def _gen_dialect_impl(self, dialect): + def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name in self._variant_mapping: return self._variant_mapping[dialect.name]._gen_dialect_impl( dialect @@ -822,7 +918,9 @@ class TypeEngine(Visitable, Generic[_T]): return dialect.type_descriptor(self) @util.memoized_property - def _static_cache_key(self): + def _static_cache_key( + self, + ) -> Union[CacheConst, Tuple[Any, ...]]: names = util.get_cls_kwargs(self.__class__) return (self.__class__,) + tuple( ( @@ -835,7 +933,17 @@ class TypeEngine(Visitable, Generic[_T]): if k in self.__dict__ and not k.startswith("_") ) - def adapt(self, cls, **kw): + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... + + @overload + def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: + ... + + def adapt( + self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any + ) -> TypeEngine[Any]: """Produce an "adapted" form of this type, given an "impl" class to work with. @@ -843,9 +951,13 @@ class TypeEngine(Visitable, Generic[_T]): types with "implementation" types that are specific to a particular dialect. """ - return util.constructor_copy(self, cls, **kw) + return util.constructor_copy( + self, cast(Type[TypeEngine[Any]], cls), **kw + ) - def coerce_compared_value(self, op, value): + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: """Suggest a type for a 'coerced' Python value in an expression. Given an operator and value, gives the type a chance @@ -873,10 +985,10 @@ class TypeEngine(Visitable, Generic[_T]): else: return _coerced_type - def _compare_type_affinity(self, other): + def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool: return self._type_affinity is other._type_affinity - def compile(self, dialect=None): + def compile(self, dialect: Optional[Dialect] = None) -> str: """Produce a string-compiled form of this :class:`.TypeEngine`. When called with no arguments, uses a "default" dialect @@ -888,24 +1000,65 @@ class TypeEngine(Visitable, Generic[_T]): # arg, return value is inconsistent with # ClauseElement.compile()....this is a mistake. - if not dialect: + if dialect is None: dialect = self._default_dialect() - return dialect.type_compiler.process(self) + return dialect.type_compiler_instance.process(self) @util.preload_module("sqlalchemy.engine.default") - def _default_dialect(self): - default = util.preloaded.engine_default - return default.StrCompileDialect() + def _default_dialect(self) -> Dialect: - def __str__(self): + if TYPE_CHECKING: + from ..engine import default + else: + default = util.preloaded.engine_default + + # dmypy / mypy seems to sporadically keep thinking this line is + # returning Any, which seems to be caused by the @deprecated_params + # decorator on the DefaultDialect constructor + return default.StrCompileDialect() # type: ignore + + def __str__(self) -> str: return str(self.compile()) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr(self) -class ExternalType: +class TypeEngineMixin: + """classes which subclass this can act as "mixin" classes for + TypeEngine.""" + + __slots__ = () + + if TYPE_CHECKING: + + @util.memoized_property + def _static_cache_key( + self, + ) -> Union[CacheConst, Tuple[Any, ...]]: + ... + + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... + + @overload + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: + ... + + def adapt( + self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any + ) -> TypeEngine[Any]: + ... + + def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: + ... + + +class ExternalType(TypeEngineMixin): """mixin that defines attributes and behaviors specific to third-party datatypes. @@ -1057,13 +1210,18 @@ class ExternalType: """ # noqa: E501 - @property - def _static_cache_key(self): + @util.non_memoized_property + def _static_cache_key( + self, + ) -> Union[CacheConst, Tuple[Any, ...]]: cache_ok = self.__class__.__dict__.get("cache_ok", None) if cache_ok is None: - subtype_idx = self.__class__.__mro__.index(ExternalType) - subtype = self.__class__.__mro__[max(subtype_idx - 1, 0)] + for subtype in self.__class__.__mro__: + if ExternalType in subtype.__bases__: + break + else: + subtype = self.__class__.__mro__[1] util.warn( "%s %r will not produce a cache key because " @@ -1076,12 +1234,14 @@ class ExternalType: code="cprf", ) elif cache_ok is True: - return super(ExternalType, self)._static_cache_key + return super()._static_cache_key return NO_CACHE -class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg): +class UserDefinedType( + ExternalType, TypeEngineMixin, TypeEngine[_T], util.EnsureKWArg +): """Base for user defined types. This should be the base of new types. Note that @@ -1148,7 +1308,9 @@ class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg): ensure_kwarg = "get_col_spec" - def coerce_compared_value(self, op, value): + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: """Suggest a type for a 'coerced' Python value in an expression. Default behavior for :class:`.UserDefinedType` is the @@ -1162,7 +1324,7 @@ class UserDefinedType(ExternalType, TypeEngine, util.EnsureKWArg): return self -class Emulated: +class Emulated(TypeEngineMixin): """Mixin for base types that emulate the behavior of a DB-native type. An :class:`.Emulated` type will use an available database type @@ -1180,7 +1342,13 @@ class Emulated: """ - def adapt_to_emulated(self, impltype, **kw): + native: bool + + def adapt_to_emulated( + self, + impltype: Type[Union[TypeEngine[Any], TypeEngineMixin]], + **kw: Any, + ) -> TypeEngine[Any]: """Given an impl class, adapt this type to the impl assuming "emulated". The impl should also be an "emulated" version of this type, @@ -1189,27 +1357,43 @@ class Emulated: e.g.: sqltypes.Enum adapts to the Enum class. """ - return super(Emulated, self).adapt(impltype, **kw) + return super().adapt(impltype, **kw) - def adapt(self, impltype, **kw): - if hasattr(impltype, "adapt_emulated_to_native"): + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... + + @overload + def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: + ... + + def adapt( + self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any + ) -> TypeEngine[Any]: + if _is_native_for_emulated(cls): if self.native: # native support requested, dialect gave us a native # implementor, pass control over to it - return impltype.adapt_emulated_to_native(self, **kw) + return cls.adapt_emulated_to_native(self, **kw) else: # non-native support, let the native implementor # decide also, at the moment this is just to help debugging # as only the default logic is implemented. - return impltype.adapt_native_to_emulated(self, **kw) + return cls.adapt_native_to_emulated(self, **kw) else: - if issubclass(impltype, self.__class__): - return self.adapt_to_emulated(impltype, **kw) + if issubclass(cls, self.__class__): + return self.adapt_to_emulated(cls, **kw) else: - return super(Emulated, self).adapt(impltype, **kw) + return super().adapt(cls, **kw) + +def _is_native_for_emulated( + typ: Type[Union[TypeEngine[Any], TypeEngineMixin]], +) -> TypeGuard["Type[NativeForEmulated]"]: + return hasattr(typ, "adapt_emulated_to_native") -class NativeForEmulated: + +class NativeForEmulated(TypeEngineMixin): """Indicates DB-native types supported by an :class:`.Emulated` type. .. versionadded:: 1.2.0b3 @@ -1217,7 +1401,11 @@ class NativeForEmulated: """ @classmethod - def adapt_native_to_emulated(cls, impl, **kw): + def adapt_native_to_emulated( + cls, + impl: Union[TypeEngine[Any], TypeEngineMixin], + **kw: Any, + ) -> TypeEngine[Any]: """Given an impl, adapt this type's class to the impl assuming "emulated". @@ -1227,7 +1415,12 @@ class NativeForEmulated: return impl.adapt(impltype, **kw) @classmethod - def adapt_emulated_to_native(cls, impl, **kw): + def adapt_emulated_to_native( + cls, + impl: Union[TypeEngine[Any], TypeEngineMixin], + **kw: Any, + ) -> TypeEngine[Any]: + """Given an impl, adapt this type's class to the impl assuming "native". The impl will be an :class:`.Emulated` class but not a @@ -1236,10 +1429,20 @@ class NativeForEmulated: e.g.: postgresql.ENUM produces a type given an Enum instance. """ - return cls(**kw) + # dmypy seems to crash on this + return cls(**kw) # type: ignore -class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): + # dmypy seems to crash with this, on repeated runs with changes + # if TYPE_CHECKING: + # def __init__(self, **kw: Any): + # ... + + +SelfTypeDecorator = TypeVar("SelfTypeDecorator", bound="TypeDecorator[Any]") + + +class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): """Allows the creation of types which add additional functionality to an existing type. @@ -1358,9 +1561,24 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): _is_type_decorator = True + # this is that pattern I've used in a few places (Dialect.dbapi, + # Dialect.type_compiler) where the "cls.attr" is a class to make something, + # and "instance.attr" is an instance of that thing. It's such a nifty, + # great pattern, and there is zero chance Python typing tools will ever be + # OK with it. For TypeDecorator.impl, this is a highly public attribute so + # we really can't change its behavior without a major deprecation routine. impl: Union[TypeEngine[Any], Type[TypeEngine[Any]]] - def __init__(self, *args, **kwargs): + # we are changing its behavior *slightly*, which is that we now consume + # the instance level version from this memoized property instead, so you + # can't reassign "impl" on an existing TypeDecorator that's already been + # used (something one shouldn't do anyway) without also updating + # impl_instance. + @util.memoized_property + def impl_instance(self) -> TypeEngine[Any]: + return self.impl # type: ignore + + def __init__(self, *args: Any, **kwargs: Any): """Construct a :class:`.TypeDecorator`. Arguments sent here are passed to the constructor @@ -1385,9 +1603,10 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): "'impl' which refers to the class of " "type being decorated" ) + self.impl = to_instance(self.__class__.impl, *args, **kwargs) - coerce_to_is_types = (util.NoneType,) + coerce_to_is_types: Sequence[Type[Any]] = (type(None),) """Specify those Python types which should be coerced at the expression level to "IS " when compared using ``==`` (and same for ``IS NOT`` in conjunction with ``!=``). @@ -1416,33 +1635,42 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): __slots__ = () - def operate(self, op, *other, **kwargs): + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[_CT]: + if TYPE_CHECKING: + assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super(TypeDecorator.Comparator, self).operate( op, *other, **kwargs ) - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[_CT]: + if TYPE_CHECKING: + assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super(TypeDecorator.Comparator, self).reverse_operate( op, other, **kwargs ) @property - def comparator_factory(self) -> Callable[..., TypeEngine.Comparator[_T]]: - if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: + def comparator_factory( # type: ignore # mypy properties bug + self, + ) -> _ComparatorFactory[Any]: + if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: # type: ignore # noqa E501 return self.impl.comparator_factory else: + # reconcile the Comparator class on the impl with that + # of TypeDecorator return type( "TDComparator", - (TypeDecorator.Comparator, self.impl.comparator_factory), + (TypeDecorator.Comparator, self.impl.comparator_factory), # type: ignore # noqa E501 {}, ) - def _gen_dialect_impl(self, dialect): - """ - #todo - """ + def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]: if dialect.name in self._variant_mapping: adapted = dialect.type_descriptor( self._variant_mapping[dialect.name] @@ -1463,35 +1691,34 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): "implement the copy() method, it must " "return an object of type %s" % (self, self.__class__) ) - tt.impl = typedesc + tt.impl = tt.impl_instance = typedesc return tt - @property - def _type_affinity(self): - """ - #todo - """ - return self.impl._type_affinity + @util.non_memoized_property + def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]: + return self.impl_instance._type_affinity - def _set_parent(self, column, outer=False, **kw): + def _set_parent( + self, parent: SchemaEventTarget, outer: bool = False, **kw: Any + ) -> None: """Support SchemaEventTarget""" - super(TypeDecorator, self)._set_parent(column) + super()._set_parent(parent) - if not outer and isinstance(self.impl, SchemaEventTarget): - self.impl._set_parent(column, outer=False, **kw) + if not outer and isinstance(self.impl_instance, SchemaEventTarget): + self.impl_instance._set_parent(parent, outer=False, **kw) - def _set_parent_with_dispatch(self, parent): + def _set_parent_with_dispatch( + self, parent: SchemaEventTarget, **kw: Any + ) -> None: """Support SchemaEventTarget""" - super(TypeDecorator, self)._set_parent_with_dispatch( - parent, outer=True - ) + super()._set_parent_with_dispatch(parent, outer=True, **kw) - if isinstance(self.impl, SchemaEventTarget): - self.impl._set_parent_with_dispatch(parent) + if isinstance(self.impl_instance, SchemaEventTarget): + self.impl_instance._set_parent_with_dispatch(parent) - def type_engine(self, dialect): + def type_engine(self, dialect: Dialect) -> TypeEngine[Any]: """Return a dialect-specific :class:`.TypeEngine` instance for this :class:`.TypeDecorator`. @@ -1508,7 +1735,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): else: return self.load_dialect_impl(dialect) - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: """Return a :class:`.TypeEngine` object corresponding to a dialect. This is an end-user override hook that can be used to provide @@ -1520,9 +1747,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): By default returns ``self.impl``. """ - return self.impl + return self.impl_instance - def _unwrapped_dialect_impl(self, dialect): + def _unwrapped_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: """Return the 'unwrapped' dialect impl for this type. This is used by the :meth:`.DefaultDialect.set_input_sizes` @@ -1540,12 +1767,14 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): else: return typ - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: """Proxy all other undefined accessors to the underlying implementation.""" - return getattr(self.impl, key) + return getattr(self.impl_instance, key) - def process_literal_param(self, value, dialect): + def process_literal_param( + self, value: Optional[_T], dialect: Dialect + ) -> str: """Receive a literal parameter value to be rendered inline within a statement. @@ -1568,7 +1797,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): """ raise NotImplementedError() - def process_bind_param(self, value, dialect): + def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any: """Receive a bound parameter value to be converted. Custom subclasses of :class:`_types.TypeDecorator` should override @@ -1595,7 +1824,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): raise NotImplementedError() - def process_result_value(self, value, dialect): + def process_result_value( + self, value: Optional[Any], dialect: Any + ) -> Optional[_T]: """Receive a result-row column value to be converted. Custom subclasses of :class:`_types.TypeDecorator` should override @@ -1624,7 +1855,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): raise NotImplementedError() @util.memoized_property - def _has_bind_processor(self): + def _has_bind_processor(self) -> bool: """memoized boolean, check if process_bind_param is implemented. Allows the base process_bind_param to raise @@ -1638,14 +1869,16 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): ) @util.memoized_property - def _has_literal_processor(self): + def _has_literal_processor(self) -> bool: """memoized boolean, check if process_literal_param is implemented.""" return util.method_is_overridden( self, TypeDecorator.process_literal_param ) - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: """Provide a literal processing function for the given :class:`.Dialect`. @@ -1661,34 +1894,59 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): "inner" processing provided by the implementing type is maintained. """ + if self._has_literal_processor: - process_param = self.process_literal_param + process_literal_param = self.process_literal_param + process_bind_param = None elif self._has_bind_processor: - # the bind processor should normally be OK - # for TypeDecorator since it isn't doing DB-level - # handling, the handling here won't be different for bound vs. - # literals. - process_param = self.process_bind_param + # use the bind processor if dont have a literal processor, + # but we have an impl literal processor + process_literal_param = None + process_bind_param = self.process_bind_param else: - process_param = None + process_literal_param = None + process_bind_param = None - if process_param: - impl_processor = self.impl.literal_processor(dialect) + if process_literal_param is not None: + impl_processor = self.impl_instance.literal_processor(dialect) if impl_processor: - def process(value): - return impl_processor(process_param(value, dialect)) + fixed_impl_processor = impl_processor + fixed_process_literal_param = process_literal_param + + def process(value: Any) -> str: + return fixed_impl_processor( + fixed_process_literal_param(value, dialect) + ) else: + fixed_process_literal_param = process_literal_param - def process(value): - return process_param(value, dialect) + def process(value: Any) -> str: + return fixed_process_literal_param(value, dialect) return process + + elif process_bind_param is not None: + impl_processor = self.impl_instance.literal_processor(dialect) + if not impl_processor: + return None + else: + fixed_impl_processor = impl_processor + fixed_process_bind_param = process_bind_param + + def process(value: Any) -> str: + return fixed_impl_processor( + fixed_process_bind_param(value, dialect) + ) + + return process else: - return self.impl.literal_processor(dialect) + return self.impl_instance.literal_processor(dialect) - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[_T]]: """Provide a bound value processing function for the given :class:`.Dialect`. @@ -1708,23 +1966,28 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): """ if self._has_bind_processor: process_param = self.process_bind_param - impl_processor = self.impl.bind_processor(dialect) + impl_processor = self.impl_instance.bind_processor(dialect) if impl_processor: + fixed_impl_processor = impl_processor + fixed_process_param = process_param - def process(value): - return impl_processor(process_param(value, dialect)) + def process(value: Optional[_T]) -> Any: + return fixed_impl_processor( + fixed_process_param(value, dialect) + ) else: + fixed_process_param = process_param - def process(value): - return process_param(value, dialect) + def process(value: Optional[_T]) -> Any: + return fixed_process_param(value, dialect) return process else: - return self.impl.bind_processor(dialect) + return self.impl_instance.bind_processor(dialect) @util.memoized_property - def _has_result_processor(self): + def _has_result_processor(self) -> bool: """memoized boolean, check if process_result_value is implemented. Allows the base process_result_value to raise @@ -1737,7 +2000,9 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): self, TypeDecorator.process_result_value ) - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: Any + ) -> Optional[_ResultProcessorType[_T]]: """Provide a result value processing function for the given :class:`.Dialect`. @@ -1758,30 +2023,39 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): """ if self._has_result_processor: process_value = self.process_result_value - impl_processor = self.impl.result_processor(dialect, coltype) + impl_processor = self.impl_instance.result_processor( + dialect, coltype + ) if impl_processor: + fixed_process_value = process_value + fixed_impl_processor = impl_processor - def process(value): - return process_value(impl_processor(value), dialect) + def process(value: Any) -> Optional[_T]: + return fixed_process_value( + fixed_impl_processor(value), dialect + ) else: + fixed_process_value = process_value - def process(value): - return process_value(value, dialect) + def process(value: Any) -> Optional[_T]: + return fixed_process_value(value, dialect) return process else: - return self.impl.result_processor(dialect, coltype) + return self.impl_instance.result_processor(dialect, coltype) @util.memoized_property - def _has_bind_expression(self): + def _has_bind_expression(self) -> bool: return ( util.method_is_overridden(self, TypeDecorator.bind_expression) - or self.impl._has_bind_expression + or self.impl_instance._has_bind_expression ) - def bind_expression(self, bindparam): + def bind_expression( + self, bindparam: BindParameter[_T] + ) -> Optional[ColumnElement[_T]]: """Given a bind value (i.e. a :class:`.BindParameter` instance), return a SQL expression which will typically wrap the given parameter. @@ -1800,10 +2074,10 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): type. """ - return self.impl.bind_expression(bindparam) + return self.impl_instance.bind_expression(bindparam) @util.memoized_property - def _has_column_expression(self): + def _has_column_expression(self) -> bool: """memoized boolean, check if column_expression is implemented. Allows the method to be skipped for the vast majority of expression @@ -1813,10 +2087,12 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): return ( util.method_is_overridden(self, TypeDecorator.column_expression) - or self.impl._has_column_expression + or self.impl_instance._has_column_expression ) - def column_expression(self, column): + def column_expression( + self, column: ColumnElement[_T] + ) -> Optional[ColumnElement[_T]]: """Given a SELECT column expression, return a wrapping SQL expression. .. note:: @@ -1838,9 +2114,11 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): """ - return self.impl.column_expression(column) + return self.impl_instance.column_expression(column) - def coerce_compared_value(self, op, value): + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> Any: """Suggest a type for a 'coerced' Python value in an expression. By default, returns self. This method is called by @@ -1858,7 +2136,7 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): """ return self - def copy(self, **kw): + def copy(self: SelfTypeDecorator, **kw: Any) -> SelfTypeDecorator: """Produce a copy of this :class:`.TypeDecorator` instance. This is a shallow copy and is provided to fulfill part of @@ -1872,16 +2150,16 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): instance.__dict__.update(self.__dict__) return instance - def get_dbapi_type(self, dbapi): + def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]: """Return the DBAPI type object represented by this :class:`.TypeDecorator`. By default this calls upon :meth:`.TypeEngine.get_dbapi_type` of the underlying "impl". """ - return self.impl.get_dbapi_type(dbapi) + return self.impl_instance.get_dbapi_type(dbapi) - def compare_values(self, x, y): + def compare_values(self, x: Any, y: Any) -> bool: """Given two values, compare them for equality. By default this calls upon :meth:`.TypeEngine.compare_values` @@ -1894,46 +2172,60 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine[_T]): has occurred. """ - return self.impl.compare_values(x, y) + return self.impl_instance.compare_values(x, y) + # mypy property bug @property - def sort_key_function(self): - return self.impl.sort_key_function + def sort_key_function(self) -> Optional[Callable[[Any], Any]]: # type: ignore # noqa E501 + return self.impl_instance.sort_key_function - def __repr__(self): - return util.generic_repr(self, to_inspect=self.impl) + def __repr__(self) -> str: + return util.generic_repr(self, to_inspect=self.impl_instance) -class Variant(TypeDecorator): +class Variant(TypeDecorator[_T]): """deprecated. symbol is present for backwards-compatibility with workaround recipes, however this actual type should not be used. """ - def __init__(self, *arg, **kw): + def __init__(self, *arg: Any, **kw: Any): raise NotImplementedError( "Variant is no longer used in SQLAlchemy; this is a " "placeholder symbol for backwards compatibility." ) -def _reconstitute_comparator(expression): +def _reconstitute_comparator(expression: Any) -> Any: return expression.comparator +@overload +def to_instance(typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any) -> _TE: + ... + + +@overload +def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]: + ... + + def to_instance( - typeobj: Union[Type[TypeEngine[_T]], TypeEngine[_T], None], *arg, **kw -) -> TypeEngine[_T]: + typeobj: Union[Type[_TE], _TE, None], *arg: Any, **kw: Any +) -> Union[_TE, TypeEngine[None]]: if typeobj is None: return NULLTYPE if callable(typeobj): - return typeobj(*arg, **kw) + return typeobj(*arg, **kw) # type: ignore # for pyright else: return typeobj -def adapt_type(typeobj, colspecs): +def adapt_type( + typeobj: TypeEngine[Any], + colspecs: Mapping[Type[Any], Type[TypeEngine[Any]]], +) -> TypeEngine[Any]: if isinstance(typeobj, type): typeobj = typeobj() for t in typeobj.__class__.__mro__[0:-1]: diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index a8e58a8bfb..5c536b675f 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -36,7 +36,7 @@ _T = TypeVar("_T", bound=Any) # https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators -_F = TypeVar("_F", bound=Callable[..., Any]) +_F = TypeVar("_F", bound="Callable[..., Any]") def _warn_with_version( diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index ef8520c4f8..35294715cb 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -178,7 +178,8 @@ def clsname_as_plain_name(cls: Type[Any]) -> str: def method_is_overridden( - instance_or_cls: Union[Type[Any], object], against_method: types.MethodType + instance_or_cls: Union[Type[Any], object], + against_method: Callable[..., Any], ) -> bool: """Return True if the two class methods don't match.""" @@ -815,7 +816,12 @@ def unbound_method_to_callable(func_or_cls): return func_or_cls -def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): +def generic_repr( + obj: Any, + additional_kw: Sequence[Tuple[str, Any]] = (), + to_inspect: Optional[Union[object, List[object]]] = None, + omit_kwarg: Sequence[str] = (), +) -> str: """Produce a __repr__() based on direct association of the __init__() specification vs. same-named attributes present. diff --git a/pyproject.toml b/pyproject.toml index 6cfa8db469..7117c2689f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ module = [ "sqlalchemy.sql.cache_key", "sqlalchemy.sql._elements_constructors", "sqlalchemy.sql.operators", + "sqlalchemy.sql.type_api", "sqlalchemy.sql.roles", "sqlalchemy.sql.visitors", "sqlalchemy.sql._py_util", @@ -109,6 +110,7 @@ strict = true module = [ #"sqlalchemy.sql.*", + "sqlalchemy.sql.sqltypes", "sqlalchemy.sql.elements", "sqlalchemy.sql.coercions", "sqlalchemy.sql.compiler", diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 8b950026f2..e50d515fb4 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -355,7 +355,6 @@ class ExecuteTest(fixtures.TablesTest): ) ) ) - assert_raises_message( TypeError, "I'm not a DBAPI error", @@ -379,6 +378,8 @@ class ExecuteTest(fixtures.TablesTest): # have any special behaviors with patch.object( testing.db.dialect, "dbapi", Mock(Error=DBAPIError) + ), patch.object( + testing.db.dialect, "loaded_dbapi", Mock(Error=DBAPIError) ), patch.object( testing.db.dialect, "is_disconnect", lambda *arg: False ), patch.object( diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index c1613069e1..7f2dd434c3 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -1682,7 +1682,7 @@ class QueuePoolTest(PoolTestBase): dialect = Mock() dialect.is_disconnect = lambda *arg, **kw: True - dialect.dbapi.Error = Error + dialect.dbapi.Error = dialect.loaded_dbapi.Error = Error pools = [] diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index b7bae0185d..78b0a467ce 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -25,7 +25,8 @@ expr4 = -c2 expr5 = ~(c2 == 5) -expr6 = ~column("q", Boolean) +q = column("q", Boolean) +expr6 = ~q expr7 = c1 + "x" diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 0924941658..11c3e83b7d 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -2100,6 +2100,8 @@ class SchemaTypeTest(fixtures.TestBase): # causes collection-mutate-while-iterated errors in the event system # since the hooks here call upon the adapted type. Need to figure out # why Enum and Boolean don't have this problem. + # NOTE: it's likely the need for the SchemaType.adapt() method, + # which Enum / Boolean don't use (and crash if it comes first) class MyType(TrackEvents, sqltypes.SchemaType, sqltypes.TypeEngine): pass @@ -2141,7 +2143,7 @@ class SchemaTypeTest(fixtures.TestBase): # [ticket:3832] # this also serves as the test for [ticket:6152] - class MySchemaType(sqltypes.TypeEngine, sqltypes.SchemaType): + class MySchemaType(sqltypes.SchemaType): pass target_typ = MySchemaType() diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 07003abcb5..9812f84c15 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -954,7 +954,7 @@ class QuotedIdentTest(fixtures.TestBase): eq_(q2.quote, False) def test_coerce_none(self): - q1 = quoted_name(None, False) + q1 = quoted_name.construct(None, False) eq_(q1, None) def test_apply_map_quoted(self): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index acf16565a0..d496b323b5 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -72,7 +72,9 @@ from sqlalchemy.sql import null from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table +from sqlalchemy.sql import type_api from sqlalchemy.sql import visitors +from sqlalchemy.sql.compiler import TypeCompiler from sqlalchemy.sql.sqltypes import TypeEngine from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -117,10 +119,15 @@ def _types_for_mod(mod): def _all_types(omit_special_types=False): seen = set() for typ in _types_for_mod(types): - if omit_special_types and typ in ( - types.TypeDecorator, - types.TypeEngine, - types.Variant, + if omit_special_types and ( + typ + in ( + TypeEngine, + type_api.TypeEngineMixin, + types.Variant, + types.TypeDecorator, + ) + or type_api.TypeEngineMixin in typ.__bases__ ): continue @@ -3553,7 +3560,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): return "MYINTEGER %s" % kw["type_expression"].name dialect = default.DefaultDialect() - dialect.type_compiler = SomeTypeCompiler(dialect) + dialect.type_compiler_instance = SomeTypeCompiler(dialect) self.assert_compile( ddl.CreateColumn(Column("bar", VARCHAR(50))), "bar MYVARCHAR", @@ -3565,6 +3572,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect, ) + def test_legacy_typecompiler_attribute(self): + """the .type_compiler attribute was broken into + .type_compiler_cls and .type_compiler_instance for 2.0 so that it can + be properly typed. However it is expected that the majority of + dialects make use of the .type_compiler attribute both at the class + level as well as the instance level, so make sure it still functions + in exactly the same way, both as the type compiler class to be + used as well as that it's present as an instance on an instance + of the dialect. + + """ + + dialect = default.DefaultDialect() + assert isinstance( + dialect.type_compiler_instance, dialect.type_compiler_cls + ) + is_(dialect.type_compiler_instance, dialect.type_compiler) + + class MyTypeCompiler(TypeCompiler): + pass + + class MyDialect(default.DefaultDialect): + type_compiler = MyTypeCompiler + + dialect = MyDialect() + assert isinstance(dialect.type_compiler_instance, MyTypeCompiler) + is_(dialect.type_compiler_instance, dialect.type_compiler) + class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase): __backend__ = True