From: Mike Bayer
Date: Thu, 10 Mar 2022 16:57:00 +0000 (-0500)
Subject: additional mypy strictness
X-Git-Tag: rel_2_0_0b1~428^2
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4c28867f944637ef313f98d5f09da05255418c6d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

additional mypy strictness

Enable type checking within untyped defs.  This allowed some more
internals to be fixed up with assertions etc.

Some internals that were unnecessary or not even used at all were
removed.  BaseCursorResult was no longer necessary since we only have
one kind of CursorResult now.  The different ResultProxy subclasses
that had alternate "strategies" don't appear to be used at all even in
1.4.x, as there's no code that accesses the _cursor_strategy_cls
attribute, which is also removed.

As these were mostly private constructs that weren't even functioning
correctly in any case, it's fine to remove them over the 2.0 boundary.

Change-Id: Ifd536987d104b1cd8b546cefdbd5c1e5d1801082
---
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 7ceb33c7ca..de01a1b461 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -9,12 +9,8 @@ from __future__ import annotations
 
 from . import util as _util
 from .engine import AdaptedConnection as AdaptedConnection
-from .engine import BaseCursorResult as BaseCursorResult
 from .engine import BaseRow as BaseRow
 from .engine import BindTyping as BindTyping
-from .engine import BufferedColumnResultProxy as BufferedColumnResultProxy
-from .engine import BufferedColumnRow as BufferedColumnRow
-from .engine import BufferedRowResultProxy as BufferedRowResultProxy
 from .engine import ChunkedIteratorResult as ChunkedIteratorResult
 from .engine import Compiled as Compiled
 from .engine import Connection as Connection
@@ -28,7 +24,6 @@ from .engine import engine_from_config as engine_from_config
 from .engine import ExceptionContext as ExceptionContext
 from .engine import ExecutionContext as ExecutionContext
 from .engine import FrozenResult as FrozenResult
-from .engine import FullyBufferedResultProxy as FullyBufferedResultProxy
 from .engine import Inspector as Inspector
 from .engine import IteratorResult as IteratorResult
 from .engine import make_url as make_url
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 32f3f2eccd..29dd6aff90 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -25,12 +25,7 @@ from .base import Transaction as Transaction
 from .base import TwoPhaseTransaction as TwoPhaseTransaction
 from .create import create_engine as create_engine
 from .create import engine_from_config as engine_from_config
-from .cursor import BaseCursorResult as BaseCursorResult
-from .cursor import BufferedColumnResultProxy as BufferedColumnResultProxy
-from .cursor import BufferedColumnRow as BufferedColumnRow
-from .cursor import BufferedRowResultProxy as BufferedRowResultProxy
 from .cursor import CursorResult as CursorResult
-from .cursor import FullyBufferedResultProxy as FullyBufferedResultProxy
 from .cursor import ResultProxy as ResultProxy
 from .interfaces import AdaptedConnection as AdaptedConnection
 from .interfaces import BindTyping as BindTyping
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 37faa880ec..d8009e26c6 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -23,6 +23,7 @@ from typing import Tuple
 from typing import Type
 from typing import Union
 
+from .interfaces import _IsolationLevel
from .interfaces import BindTyping from .interfaces import ConnectionEventsTarget from .interfaces import DBAPICursor @@ -510,7 +511,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._handle_dbapi_exception(e, None, None, None, None) @property - def default_isolation_level(self) -> str: + def default_isolation_level(self) -> Optional[_IsolationLevel]: """The default isolation level assigned to this :class:`_engine.Connection`. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 78805bac1b..821c0cb8e3 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -23,8 +23,9 @@ from typing import List from typing import Optional from typing import Sequence from typing import Tuple -from typing import Type +from typing import Union +from .result import MergedResult from .result import Result from .result import ResultMetaData from .result import SimpleResultMetaData @@ -36,10 +37,12 @@ from ..sql import elements from ..sql import sqltypes from ..sql import util as sql_util from ..sql.base import _generative +from ..sql.compiler import ResultColumnsEntry from ..sql.compiler import RM_NAME from ..sql.compiler import RM_OBJECTS from ..sql.compiler import RM_RENDERED_NAME from ..sql.compiler import RM_TYPE +from ..sql.type_api import TypeEngine from ..util import compat from ..util.typing import Literal @@ -101,6 +104,7 @@ class CursorResultMetaData(ResultMetaData): _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]] _unpickled: bool _safe_for_cache: bool + _translated_indexes: Optional[List[int]] returns_rows: ClassVar[bool] = True @@ -123,7 +127,6 @@ class CursorResultMetaData(ResultMetaData): if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] - tup = tuplegetter(*indexes) new_metadata = self.__class__.__new__(self.__class__) @@ -526,7 +529,7 @@ class CursorResultMetaData(ResultMetaData): def _merge_textual_cols_by_position( self, context, cursor_description, result_columns ): - num_ctx_cols = len(result_columns) if result_columns else None + num_ctx_cols = len(result_columns) if num_ctx_cols > len(cursor_description): util.warn( @@ -568,6 +571,8 @@ class CursorResultMetaData(ResultMetaData): match_map = self._create_description_match_map( result_columns, loose_column_name_matching ) + mapped_type: TypeEngine[Any] + for ( idx, colname, @@ -597,15 +602,17 @@ class CursorResultMetaData(ResultMetaData): @classmethod def _create_description_match_map( cls, - result_columns, - loose_column_name_matching=False, - ): + result_columns: List[ResultColumnsEntry], + loose_column_name_matching: bool = False, + ) -> Dict[Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int]]: """when matching cursor.description to a set of names that are present in a Compiled object, as is the case with TextualSelect, get all the names we expect might match those in cursor.description. 
""" - d = {} + d: Dict[ + Union[str, object], Tuple[str, List[Any], TypeEngine[Any], int] + ] = {} for ridx, elem in enumerate(result_columns): key = elem[RM_RENDERED_NAME] @@ -630,7 +637,6 @@ class CursorResultMetaData(ResultMetaData): r_key, (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx), ) - return d def _merge_cols_by_none(self, context, cursor_description): @@ -739,7 +745,9 @@ class CursorResultMetaData(ResultMetaData): self._keys = state["_keys"] self._unpickled = True if state["_translated_indexes"]: - self._translated_indexes = state["_translated_indexes"] + self._translated_indexes = cast( + "List[int]", state["_translated_indexes"] + ) self._tuplefilter = tuplegetter(*self._translated_indexes) else: self._translated_indexes = self._tuplefilter = None @@ -1144,12 +1152,32 @@ class _NoResultMetaData(ResultMetaData): _NO_RESULT_METADATA = _NoResultMetaData() -class BaseCursorResult: - """Base class for database result objects.""" +class CursorResult(Result): + """A Result that is representing state from a DBAPI cursor. + + .. versionchanged:: 1.4 The :class:`.CursorResult`` + class replaces the previous :class:`.ResultProxy` interface. + This classes are based on the :class:`.Result` calling API + which provides an updated usage model and calling facade for + SQLAlchemy Core and SQLAlchemy ORM. + + Returns database rows via the :class:`.Row` class, which provides + additional API features and behaviors on top of the raw data returned by + the DBAPI. Through the use of filters such as the :meth:`.Result.scalars` + method, other kinds of objects may also be returned. + + .. seealso:: + + :ref:`coretutorial_selecting` - introductory material for accessing + :class:`_engine.CursorResult` and :class:`.Row` objects. - _metadata: ResultMetaData + """ + + _metadata: Union[CursorResultMetaData, _NoResultMetaData] + _no_result_metadata = _NO_RESULT_METADATA _soft_closed: bool = False closed: bool = False + _is_cursor = True def __init__(self, context, cursor_strategy, cursor_description): self.context = context @@ -1169,11 +1197,11 @@ class BaseCursorResult: if echo: log = self.context.connection._log_debug - def log_row(row): + def _log_row(row): log("Row %r", sql_util._repr_row(row)) return row - self._row_logging_fn = log_row + self._row_logging_fn = log_row = _log_row else: log_row = None @@ -1188,13 +1216,16 @@ class BaseCursorResult: ) if log_row: - def make_row(row): + def _make_row_2(row): made_row = _make_row(row) + assert log_row is not None log_row(made_row) return made_row + make_row = _make_row_2 else: make_row = _make_row + self._set_memoized_attribute("_row_getter", make_row) else: @@ -1208,7 +1239,7 @@ class BaseCursorResult: if compiled._cached_metadata: metadata = compiled._cached_metadata else: - metadata = self._cursor_metadata(self, cursor_description) + metadata = CursorResultMetaData(self, cursor_description) if metadata._safe_for_cache: compiled._cached_metadata = metadata @@ -1239,7 +1270,7 @@ class BaseCursorResult: self._metadata = metadata else: - self._metadata = metadata = self._cursor_metadata( + self._metadata = metadata = CursorResultMetaData( self, cursor_description ) if self._echo: @@ -1669,33 +1700,6 @@ class BaseCursorResult: """ return self.context.isinsert - -class CursorResult(BaseCursorResult, Result): - """A Result that is representing state from a DBAPI cursor. - - .. versionchanged:: 1.4 The :class:`.CursorResult`` - class replaces the previous :class:`.ResultProxy` interface. 
- This classes are based on the :class:`.Result` calling API - which provides an updated usage model and calling facade for - SQLAlchemy Core and SQLAlchemy ORM. - - Returns database rows via the :class:`.Row` class, which provides - additional API features and behaviors on top of the raw data returned by - the DBAPI. Through the use of filters such as the :meth:`.Result.scalars` - method, other kinds of objects may also be returned. - - .. seealso:: - - :ref:`coretutorial_selecting` - introductory material for accessing - :class:`_engine.CursorResult` and :class:`.Row` objects. - - """ - - _cursor_metadata: Type[ResultMetaData] = CursorResultMetaData - _cursor_strategy_cls = CursorFetchStrategy - _no_result_metadata = _NO_RESULT_METADATA - _is_cursor = True - def _fetchiter_impl(self): fetchone = self.cursor_strategy.fetchone @@ -1717,12 +1721,13 @@ class CursorResult(BaseCursorResult, Result): def _raw_row_iterator(self): return self._fetchiter_impl() - def merge(self, *others): - merged_result = super(CursorResult, self).merge(*others) + def merge(self, *others: Result) -> MergedResult: + merged_result = super().merge(*others) setup_rowcounts = not self._metadata.returns_rows if setup_rowcounts: merged_result.rowcount = sum( - result.rowcount for result in (self,) + others + cast(CursorResult, result).rowcount + for result in (self,) + others ) return merged_result @@ -1756,40 +1761,3 @@ class CursorResult(BaseCursorResult, Result): ResultProxy = CursorResult - - -class BufferedRowResultProxy(ResultProxy): - """A ResultProxy with row buffering behavior. - - .. deprecated:: 1.4 this class is now supplied using a strategy object. - See :class:`.BufferedRowCursorFetchStrategy`. - - """ - - _cursor_strategy_cls: Type[ - CursorFetchStrategy - ] = BufferedRowCursorFetchStrategy - - -class FullyBufferedResultProxy(ResultProxy): - """A result proxy that buffers rows fully upon creation. - - .. deprecated:: 1.4 this class is now supplied using a strategy object. - See :class:`.FullyBufferedCursorFetchStrategy`. - - """ - - _cursor_strategy_cls = FullyBufferedCursorFetchStrategy - - -class BufferedColumnRow(Row): - """Row is now BufferedColumn in all cases""" - - -class BufferedColumnResultProxy(ResultProxy): - """A ResultProxy with column buffering behavior. - - .. versionchanged:: 1.4 This is now the default behavior of the Row - and this class does not change behavior in any way. 
- - """ diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 0e0c76389a..2579f573c5 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -55,6 +55,9 @@ from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name if typing.TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .characteristics import ConnectionCharacteristic from .interfaces import _AnyMultiExecuteParams from .interfaces import _CoreMultiExecuteParams from .interfaces import _CoreSingleExecuteParams @@ -62,6 +65,7 @@ if typing.TYPE_CHECKING: from .interfaces import _DBAPIMultiExecuteParams from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions + from .interfaces import _MutableCoreSingleExecuteParams from .result import _ProcessorType from .row import Row from .url import URL @@ -71,6 +75,7 @@ if typing.TYPE_CHECKING: from ..sql import Executable from ..sql.compiler import Compiled from ..sql.compiler import ResultColumnsEntry + from ..sql.compiler import TypeCompiler from ..sql.schema import Column from ..sql.type_api import TypeEngine @@ -92,7 +97,11 @@ class DefaultDialect(Dialect): statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.GenericTypeCompiler # type: ignore + if typing.TYPE_CHECKING: + type_compiler: TypeCompiler + else: + type_compiler = compiler.GenericTypeCompiler + preparer = compiler.IdentifierPreparer supports_alter = True supports_comments = False @@ -202,7 +211,7 @@ class DefaultDialect(Dialect): server_version_info = None - default_schema_name = None + default_schema_name: Optional[str] = None # indicates symbol names are # UPPERCASEd if they are case insensitive @@ -290,7 +299,12 @@ class DefaultDialect(Dialect): self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level - self.type_compiler = self.type_compiler(self) + + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + self.type_compiler = tt_callable(self) if supports_native_boolean is not None: self.supports_native_boolean = supports_native_boolean @@ -490,12 +504,14 @@ class DefaultDialect(Dialect): opts.update(url.query) return [[], opts] - def set_engine_execution_options(self, engine, opts): + def set_engine_execution_options( + self, engine: Engine, opts: Mapping[str, str] + ) -> None: supported_names = set(self.connection_characteristics).intersection( opts ) if supported_names: - characteristics = util.immutabledict( + characteristics: Mapping[str, str] = util.immutabledict( (name, opts[name]) for name in supported_names ) @@ -505,12 +521,14 @@ class DefaultDialect(Dialect): connection, characteristics ) - def set_connection_execution_options(self, connection, opts): + def set_connection_execution_options( + self, connection: Connection, opts: Mapping[str, str] + ) -> None: supported_names = set(self.connection_characteristics).intersection( opts ) if supported_names: - characteristics = util.immutabledict( + characteristics: Mapping[str, str] = util.immutabledict( (name, opts[name]) for name in supported_names ) self._set_connection_characteristics(connection, characteristics) @@ -800,7 +818,7 @@ class DefaultExecutionContext(ExecutionContext): dialect: Dialect unicode_statement: str cursor: DBAPICursor - compiled_parameters: _CoreMultiExecuteParams + compiled_parameters: 
List[_MutableCoreSingleExecuteParams] parameters: _DBAPIMultiExecuteParams extracted_parameters: _CoreSingleExecuteParams @@ -1157,7 +1175,11 @@ class DefaultExecutionContext(ExecutionContext): parameters = {} conn._cursor_execute(self.cursor, stmt, parameters, context=self) - r = self.cursor.fetchone()[0] + row = self.cursor.fetchone() + if row is not None: + r = row[0] + else: + r = None if type_ is not None: # apply type post processors to the result proc = type_._cached_result_processor( @@ -1299,10 +1321,11 @@ class DefaultExecutionContext(ExecutionContext): result = _cursor.CursorResult(self, strategy, cursor_description) + compiled = self.compiled if ( - self.compiled + compiled and not self.isddl - and self.compiled.has_out_parameters + and cast(SQLCompiler, compiled).has_out_parameters ): self._setup_out_parameters(result) @@ -1311,10 +1334,11 @@ class DefaultExecutionContext(ExecutionContext): return result def _setup_out_parameters(self, result): + compiled = cast(SQLCompiler, self.compiled) out_bindparams = [ (param, name) - for param, name in self.compiled.bind_names.items() + for param, name in compiled.bind_names.items() if param.isoutparam ] out_parameters = {} @@ -1339,9 +1363,10 @@ class DefaultExecutionContext(ExecutionContext): result.out_parameters = out_parameters def _setup_dml_or_text_result(self): + compiled = cast(SQLCompiler, self.compiled) if self.isinsert: - if self.compiled.postfetch_lastrowid: + if compiled.postfetch_lastrowid: self.inserted_primary_key_rows = ( self._setup_ins_pk_from_lastrowid() ) @@ -1397,7 +1422,8 @@ class DefaultExecutionContext(ExecutionContext): result.rowcount row = result.fetchone() - self.returned_default_rows = [row] + if row is not None: + self.returned_default_rows = [row] result._soft_close() @@ -1420,13 +1446,17 @@ class DefaultExecutionContext(ExecutionContext): return self._setup_ins_pk_from_empty() def _setup_ins_pk_from_lastrowid(self): - getter = self.compiled._inserted_primary_key_from_lastrowid_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter lastrowid = self.get_lastrowid() return [getter(lastrowid, self.compiled_parameters[0])] def _setup_ins_pk_from_empty(self): - getter = self.compiled._inserted_primary_key_from_lastrowid_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter return [getter(None, param) for param in self.compiled_parameters] def _setup_ins_pk_from_implicit_returning(self, result, rows): @@ -1434,7 +1464,9 @@ class DefaultExecutionContext(ExecutionContext): if not rows: return [] - getter = self.compiled._inserted_primary_key_from_returning_getter + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_returning_getter compiled_params = self.compiled_parameters return [ @@ -1443,7 +1475,7 @@ class DefaultExecutionContext(ExecutionContext): def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and bool( - self.compiled.postfetch + cast(SQLCompiler, self.compiled).postfetch ) def _set_input_sizes(self): @@ -1464,7 +1496,7 @@ class DefaultExecutionContext(ExecutionContext): if self.isddl or self.is_text: return - compiled = self.compiled + compiled = cast(SQLCompiler, self.compiled) inputsizes = compiled._get_set_input_sizes_lookup() @@ -1487,7 +1519,8 @@ class DefaultExecutionContext(ExecutionContext): if dialect.positional: items = [ - (key, compiled.binds[key]) for key in compiled.positiontup + (key, compiled.binds[key]) + for key in compiled.positiontup or () ] 
else: items = [ @@ -1495,7 +1528,7 @@ class DefaultExecutionContext(ExecutionContext): for bindparam, key in compiled.bind_names.items() ] - generic_inputsizes = [] + generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = [] for key, bindparam in items: if bindparam in compiled.literal_execute_params: continue @@ -1578,20 +1611,19 @@ class DefaultExecutionContext(ExecutionContext): compiled_params = compiled.construct_params() processors = compiled._bind_processors if compiled.positional: - positiontup = compiled.positiontup parameters = self.dialect.execute_sequence_format( [ - processors[key](compiled_params[key]) + processors[key](compiled_params[key]) # type: ignore if key in processors else compiled_params[key] - for key in positiontup + for key in compiled.positiontup or () ] ) else: parameters = dict( ( key, - processors[key](compiled_params[key]) + processors[key](compiled_params[key]) # type: ignore if key in processors else compiled_params[key], ) @@ -1667,15 +1699,18 @@ class DefaultExecutionContext(ExecutionContext): "get_current_parameters() can only be invoked in the " "context of a Python side column default function" ) - - compile_state = self.compiled.compile_state + else: + assert column is not None + assert parameters is not None + compile_state = cast(SQLCompiler, self.compiled).compile_state + assert compile_state is not None if ( isolate_multiinsert_groups and self.isinsert and compile_state._has_multi_parameters ): if column._is_multiparam_column: - index = column.index + 1 + index = column.index + 1 # type: ignore d = {column.original.key: parameters[column.key]} else: d = {column.key: parameters[column.key]} @@ -1701,12 +1736,14 @@ class DefaultExecutionContext(ExecutionContext): return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): - key_getter = self.compiled._within_exec_param_key_getter + compiled = cast(SQLCompiler, self.compiled) - scalar_defaults = {} + key_getter = compiled._within_exec_param_key_getter - insert_prefetch = self.compiled.insert_prefetch - update_prefetch = self.compiled.update_prefetch + scalar_defaults: Dict[Column[Any], Any] = {} + + insert_prefetch = compiled.insert_prefetch + update_prefetch = compiled.update_prefetch # pre-determine scalar Python-side defaults # to avoid many calls of get_insert_default()/ @@ -1739,12 +1776,14 @@ class DefaultExecutionContext(ExecutionContext): del self.current_parameters def _process_executesingle_defaults(self): - key_getter = self.compiled._within_exec_param_key_getter + compiled = cast(SQLCompiler, self.compiled) + + key_getter = compiled._within_exec_param_key_getter self.current_parameters = ( compiled_parameters ) = self.compiled_parameters[0] - for c in self.compiled.insert_prefetch: + for c in compiled.insert_prefetch: if c.default and not c.default.is_sequence and c.default.is_scalar: val = c.default.arg else: @@ -1753,7 +1792,7 @@ class DefaultExecutionContext(ExecutionContext): if val is not None: compiled_parameters[key_getter(c)] = val - for c in self.compiled.update_prefetch: + for c in compiled.update_prefetch: val = self.get_update_default(c) if val is not None: diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 5aefcf5b56..e65546eb77 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -36,7 +36,6 @@ from ..sql.compiler import TypeCompiler as TypeCompiler from ..sql.compiler import TypeCompiler # noqa from ..util import immutabledict from 
..util.concurrency import await_only -from ..util.typing import _TypeToInstance from ..util.typing import Literal from ..util.typing import NotRequired from ..util.typing import Protocol @@ -58,6 +57,8 @@ if TYPE_CHECKING: from ..sql.elements import ClauseElement from ..sql.schema import Column from ..sql.schema import ColumnDefault + from ..sql.schema import Sequence as Sequence_SchemaItem + from ..sql.sqltypes import Integer from ..sql.type_api import TypeEngine ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]] @@ -156,6 +157,8 @@ class DBAPICursor(Protocol): arraysize: int + lastrowid: int + def close(self) -> None: ... @@ -196,6 +199,7 @@ class DBAPICursor(Protocol): _CoreSingleExecuteParams = Mapping[str, Any] +_MutableCoreSingleExecuteParams = MutableMapping[str, Any] _CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams] _CoreAnyExecuteParams = Union[ _CoreMultiExecuteParams, _CoreSingleExecuteParams @@ -605,7 +609,7 @@ class Dialect(EventTarget): ddl_compiler: Type[DDLCompiler] """a :class:`.Compiled` class used to compile DDL statements""" - type_compiler: _TypeToInstance[TypeCompiler] + type_compiler: Union[Type[TypeCompiler], TypeCompiler] """a :class:`.Compiled` class used to compile SQL type objects""" preparer: Type[IdentifierPreparer] @@ -633,7 +637,7 @@ class Dialect(EventTarget): """ - default_isolation_level: _IsolationLevel + default_isolation_level: Optional[_IsolationLevel] """the isolation that is implicitly present on new connections""" execution_ctx_cls: Type["ExecutionContext"] @@ -653,6 +657,13 @@ class Dialect(EventTarget): max_identifier_length: int """The maximum length of identifier names.""" + supports_server_side_cursors: bool + """indicates if the dialect supports server side cursors""" + + server_side_cursors: bool + """deprecated; indicates if the dialect should attempt to use server + side cursors by default""" + supports_sane_rowcount: bool """Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. @@ -2302,6 +2313,11 @@ class ExecutionContext: def _setup_result_proxy(self) -> Result: raise NotImplementedError() + def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: + """given a :class:`.Sequence`, invoke it and return the next int + value""" + raise NotImplementedError() + def create_cursor(self) -> DBAPICursor: """Return a new cursor generated from this ExecutionContext's connection. 
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 0951d57702..87d3cac1c7 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -1880,6 +1880,7 @@ class MergedResult(IteratorResult): """ closed = False + rowcount: Optional[int] def __init__( self, cursor_metadata: ResultMetaData, results: Sequence[Result] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 09e38a5ab9..423c3d446e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -34,6 +34,7 @@ import re from time import perf_counter import typing from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Mapping @@ -629,11 +630,11 @@ class SQLCompiler(Compiled): """list of columns that can be post-fetched after INSERT or UPDATE to receive server-updated values""" - insert_prefetch: Optional[List[Column[Any]]] + insert_prefetch: Sequence[Column[Any]] = () """list of columns for which default values should be evaluated before an INSERT takes place""" - update_prefetch: Optional[List[Column[Any]]] + update_prefetch: Sequence[Column[Any]] = () """list of columns for which onupdate default values should be evaluated before an UPDATE takes place""" @@ -739,8 +740,6 @@ class SQLCompiler(Compiled): """if True, there are bindparam() objects that have the isoutparam flag set.""" - insert_prefetch = update_prefetch = () - postfetch_lastrowid = False """if True, and this in insert, use cursor.lastrowid to populate result.inserted_primary_key. """ @@ -1340,7 +1339,7 @@ class SQLCompiler(Compiled): ) @util.memoized_property - def _within_exec_param_key_getter(self): + def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._key_getters_for_crud_column[2] if self.escaped_bind_names: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 4c38c4efab..168da17ccc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -58,12 +58,13 @@ from ..util.langhelpers import TypingOnly if typing.TYPE_CHECKING: from decimal import Decimal + from .compiler import Compiled + from .compiler import SQLCompiler from .operators import OperatorType from .selectable import FromClause from .selectable import Select from .sqltypes import Boolean # noqa from .type_api import TypeEngine - from ..engine import Compiled from ..engine import Connection from ..engine import Dialect from ..engine import Engine @@ -573,6 +574,25 @@ class ClauseElement( ) +class DQLDMLClauseElement(ClauseElement): + """represents a :class:`.ClauseElement` that compiles to a DQL or DML + expression, not DDL. + + .. versionadded:: 2.0 + + """ + + if typing.TYPE_CHECKING: + + def compile( # noqa: A001 + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> SQLCompiler: + ... + + class CompilerColumnElement( roles.DMLColumnRole, roles.DDLConstraintColumnRole, @@ -955,7 +975,7 @@ class ColumnElement( roles.DDLExpressionRole, SQLCoreOperations[_T], operators.ColumnOperators[SQLCoreOperations], - ClauseElement, + DQLDMLClauseElement, ): """Represent a column-oriented SQL expression suitable for usage in the "columns" clause, WHERE clause etc. of a statement. @@ -1820,7 +1840,7 @@ class BindParameter(roles.InElementRole, ColumnElement[_T]): ) -class TypeClause(ClauseElement): +class TypeClause(DQLDMLClauseElement): """Handle a type keyword in a SQL statement. 
Used by the ``Case`` statement. @@ -1849,7 +1869,7 @@ class TextClause( roles.BinaryElementRole, roles.InElementRole, Executable, - ClauseElement, + DQLDMLClauseElement, ): """Represent a literal SQL text fragment. @@ -2285,7 +2305,7 @@ class ClauseList( roles.OrderByRole, roles.ColumnsClauseRole, roles.DMLColumnRole, - ClauseElement, + DQLDMLClauseElement, ): """Describe a list of clauses, separated by an operator. @@ -3205,7 +3225,7 @@ class IndexExpression(BinaryExpression): inherit_cache = True -class GroupedElement(ClauseElement): +class GroupedElement(DQLDMLClauseElement): """Represent any parenthesized expression""" __visit_name__ = "grouping" diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index fdae4d7b04..c270e15648 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1131,6 +1131,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): __visit_name__ = "column" inherit_cache = True + key: str @overload def __init__( diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index a5cbffb5e1..e5c2bef686 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -62,6 +62,7 @@ from .elements import ClauseElement from .elements import ClauseList from .elements import ColumnClause from .elements import ColumnElement +from .elements import DQLDMLClauseElement from .elements import GroupedElement from .elements import Grouping from .elements import literal_column @@ -85,7 +86,7 @@ class _OffsetLimitParam(BindParameter): return self.effective_value -class ReturnsRows(roles.ReturnsRowsRole, ClauseElement): +class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): """The base-most class for Core constructs that have some concept of columns that can represent rows. 
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 45e31aaf7f..b0df99c415 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -12,117 +12,63 @@ from __future__ import annotations -from .sql.sqltypes import _Binary -from .sql.sqltypes import ARRAY -from .sql.sqltypes import BIGINT -from .sql.sqltypes import BigInteger -from .sql.sqltypes import BINARY -from .sql.sqltypes import BLOB -from .sql.sqltypes import BOOLEAN -from .sql.sqltypes import Boolean -from .sql.sqltypes import CHAR -from .sql.sqltypes import CLOB -from .sql.sqltypes import Concatenable -from .sql.sqltypes import DATE -from .sql.sqltypes import Date -from .sql.sqltypes import DATETIME -from .sql.sqltypes import DateTime -from .sql.sqltypes import DECIMAL -from .sql.sqltypes import DOUBLE -from .sql.sqltypes import Double -from .sql.sqltypes import DOUBLE_PRECISION -from .sql.sqltypes import Enum -from .sql.sqltypes import FLOAT -from .sql.sqltypes import Float -from .sql.sqltypes import Indexable -from .sql.sqltypes import INT -from .sql.sqltypes import INTEGER -from .sql.sqltypes import Integer -from .sql.sqltypes import Interval -from .sql.sqltypes import JSON -from .sql.sqltypes import LargeBinary -from .sql.sqltypes import MatchType -from .sql.sqltypes import NCHAR -from .sql.sqltypes import NULLTYPE -from .sql.sqltypes import NullType -from .sql.sqltypes import NUMERIC -from .sql.sqltypes import Numeric -from .sql.sqltypes import NVARCHAR -from .sql.sqltypes import PickleType -from .sql.sqltypes import REAL -from .sql.sqltypes import SchemaType -from .sql.sqltypes import SMALLINT -from .sql.sqltypes import SmallInteger -from .sql.sqltypes import String -from .sql.sqltypes import STRINGTYPE -from .sql.sqltypes import TEXT -from .sql.sqltypes import Text -from .sql.sqltypes import TIME -from .sql.sqltypes import Time -from .sql.sqltypes import TIMESTAMP -from .sql.sqltypes import TupleType -from .sql.sqltypes import Unicode -from .sql.sqltypes import UnicodeText -from .sql.sqltypes import VARBINARY -from .sql.sqltypes import VARCHAR -from .sql.type_api import adapt_type -from .sql.type_api import ExternalType -from .sql.type_api import to_instance -from .sql.type_api import TypeDecorator -from .sql.type_api import TypeEngine -from .sql.type_api import UserDefinedType -from .sql.type_api import Variant - -__all__ = [ - "TypeEngine", - "TypeDecorator", - "UserDefinedType", - "ExternalType", - "INT", - "CHAR", - "VARCHAR", - "NCHAR", - "NVARCHAR", - "TEXT", - "Text", - "FLOAT", - "NUMERIC", - "REAL", - "DECIMAL", - "TIMESTAMP", - "DATETIME", - "CLOB", - "BLOB", - "BINARY", - "VARBINARY", - "BOOLEAN", - "BIGINT", - "SMALLINT", - "INTEGER", - "DATE", - "TIME", - "TupleType", - "String", - "Integer", - "SmallInteger", - "BigInteger", - "Numeric", - "Float", - "Double", - "DOUBLE", - "DOUBLE_PRECISION", - "DateTime", - "Date", - "Time", - "LargeBinary", - "Boolean", - "Unicode", - "Concatenable", - "UnicodeText", - "PickleType", - "Interval", - "Enum", - "Indexable", - "ARRAY", - "JSON", -] +from .sql.sqltypes import _Binary as _Binary +from .sql.sqltypes import ARRAY as ARRAY +from .sql.sqltypes import BIGINT as BIGINT +from .sql.sqltypes import BigInteger as BigInteger +from .sql.sqltypes import BINARY as BINARY +from .sql.sqltypes import BLOB as BLOB +from .sql.sqltypes import BOOLEAN as BOOLEAN +from .sql.sqltypes import Boolean as Boolean +from .sql.sqltypes import CHAR as CHAR +from .sql.sqltypes import CLOB as CLOB +from .sql.sqltypes import Concatenable as Concatenable +from 
.sql.sqltypes import DATE as DATE +from .sql.sqltypes import Date as Date +from .sql.sqltypes import DATETIME as DATETIME +from .sql.sqltypes import DateTime as DateTime +from .sql.sqltypes import DECIMAL as DECIMAL +from .sql.sqltypes import DOUBLE as DOUBLE +from .sql.sqltypes import Double as Double +from .sql.sqltypes import DOUBLE_PRECISION as DOUBLE_PRECISION +from .sql.sqltypes import Enum as Enum +from .sql.sqltypes import FLOAT as FLOAT +from .sql.sqltypes import Float as Float +from .sql.sqltypes import Indexable as Indexable +from .sql.sqltypes import INT as INT +from .sql.sqltypes import INTEGER as INTEGER +from .sql.sqltypes import Integer as Integer +from .sql.sqltypes import Interval as Interval +from .sql.sqltypes import JSON as JSON +from .sql.sqltypes import LargeBinary as LargeBinary +from .sql.sqltypes import MatchType as MatchType +from .sql.sqltypes import NCHAR as NCHAR +from .sql.sqltypes import NULLTYPE as NULLTYPE +from .sql.sqltypes import NullType as NullType +from .sql.sqltypes import NUMERIC as NUMERIC +from .sql.sqltypes import Numeric as Numeric +from .sql.sqltypes import NVARCHAR as NVARCHAR +from .sql.sqltypes import PickleType as PickleType +from .sql.sqltypes import REAL as REAL +from .sql.sqltypes import SchemaType as SchemaType +from .sql.sqltypes import SMALLINT as SMALLINT +from .sql.sqltypes import SmallInteger as SmallInteger +from .sql.sqltypes import String as String +from .sql.sqltypes import STRINGTYPE as STRINGTYPE +from .sql.sqltypes import TEXT as TEXT +from .sql.sqltypes import Text as Text +from .sql.sqltypes import TIME as TIME +from .sql.sqltypes import Time as Time +from .sql.sqltypes import TIMESTAMP as TIMESTAMP +from .sql.sqltypes import TupleType as TupleType +from .sql.sqltypes import Unicode as Unicode +from .sql.sqltypes import UnicodeText as UnicodeText +from .sql.sqltypes import VARBINARY as VARBINARY +from .sql.sqltypes import VARCHAR as VARCHAR +from .sql.type_api import adapt_type as adapt_type +from .sql.type_api import ExternalType as ExternalType +from .sql.type_api import to_instance as to_instance +from .sql.type_api import TypeDecorator as TypeDecorator +from .sql.type_api import TypeEngine as TypeEngine +from .sql.type_api import UserDefinedType as UserDefinedType +from .sql.type_api import Variant as Variant diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index e0b53b4450..06a009c5b6 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -34,6 +34,7 @@ import weakref from ._has_cy import HAS_CYEXTENSION from .typing import Literal +from .typing import Protocol if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_collections import immutabledict as immutabledict @@ -62,7 +63,7 @@ else: _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT", bound=Any) _VT = TypeVar("_VT", bound=Any) - +_T_co = TypeVar("_T_co", covariant=True) EMPTY_SET: FrozenSet[Any] = frozenset() @@ -597,7 +598,17 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): self._mutex.release() -class ScopedRegistry: +class _CreateFuncType(Protocol[_T_co]): + def __call__(self) -> _T_co: + ... + + +class _ScopeFuncType(Protocol): + def __call__(self) -> Any: + ... + + +class ScopedRegistry(Generic[_T]): """A Registry that can store one or multiple instances of a single class on the basis of a "scope" function. 
@@ -614,6 +625,10 @@ class ScopedRegistry: __slots__ = "createfunc", "scopefunc", "registry" + createfunc: _CreateFuncType[_T] + scopefunc: _ScopeFuncType + registry: Any + def __init__(self, createfunc, scopefunc): """Construct a new :class:`.ScopedRegistry`. @@ -629,24 +644,24 @@ class ScopedRegistry: self.scopefunc = scopefunc self.registry = {} - def __call__(self): + def __call__(self) -> _T: key = self.scopefunc() try: - return self.registry[key] + return self.registry[key] # type: ignore[no-any-return] except KeyError: - return self.registry.setdefault(key, self.createfunc()) + return self.registry.setdefault(key, self.createfunc()) # type: ignore[no-any-return] # noqa: E501 - def has(self): + def has(self) -> bool: """Return True if an object is present in the current scope.""" return self.scopefunc() in self.registry - def set(self, obj): + def set(self, obj: _T) -> None: """Set the value for the current scope.""" self.registry[self.scopefunc()] = obj - def clear(self): + def clear(self) -> None: """Clear the current scope, if any.""" try: @@ -655,32 +670,32 @@ class ScopedRegistry: pass -class ThreadLocalRegistry(ScopedRegistry): +class ThreadLocalRegistry(ScopedRegistry[_T]): """A :class:`.ScopedRegistry` that uses a ``threading.local()`` variable for storage. """ - def __init__(self, createfunc): + def __init__(self, createfunc: Callable[[], _T]): self.createfunc = createfunc self.registry = threading.local() - def __call__(self): + def __call__(self) -> _T: try: - return self.registry.value + return self.registry.value # type: ignore[no-any-return] except AttributeError: val = self.registry.value = self.createfunc() - return val + return val # type: ignore[no-any-return] - def has(self): + def has(self) -> bool: return hasattr(self.registry, "value") - def set(self, obj): + def set(self, obj: _T) -> None: self.registry.value = obj - def clear(self): + def clear(self) -> None: try: - del self.registry.value # type: ignore + del self.registry.value except AttributeError: pass diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 771e974e93..d503529303 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -11,6 +11,7 @@ from itertools import filterfalse from typing import AbstractSet from typing import Any from typing import cast +from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -67,7 +68,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.__init__(new, *args) return new - def __init__(self, *args: Union[Mapping[_KT, _VT], Tuple[_KT, _VT]]): + def __init__( + self, *args: Union[Mapping[_KT, _VT], Iterable[Tuple[_KT, _VT]]] + ): pass def __reduce__(self): @@ -369,6 +372,8 @@ class IdentitySet: def difference(self, iterable): result = self.__new__(self.__class__) + other: Collection[Any] + if isinstance(iterable, self.__class__): other = iterable._members else: @@ -394,6 +399,9 @@ class IdentitySet: def intersection(self, iterable): result = self.__new__(self.__class__) + + other: Collection[Any] + if isinstance(iterable, self.__class__): other = iterable._members else: @@ -466,7 +474,7 @@ class IdentitySet: def unique_list(seq, hashfunc=None): - seen = set() + seen: Set[Any] = set() seen_add = seen.add if not hashfunc: return [x for x in seq if x not in seen and not seen_add(x)] diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 5674e19afe..8cb84f73f5 100644 --- 
a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -679,7 +679,7 @@ def create_proxy_methods( def decorate(cls): def instrument(name, clslevel=False): - fn = cast(Callable[..., Any], getattr(target_cls, name)) + fn = cast(types.FunctionType, getattr(target_cls, name)) spec = compat.inspect_getfullargspec(fn) env = {"__name__": fn.__module__} @@ -709,7 +709,7 @@ def create_proxy_methods( ) proxy_fn = cast( - Callable[..., Any], _exec_code_in_env(code, env, fn.__name__) + types.FunctionType, _exec_code_in_env(code, env, fn.__name__) ) proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__ proxy_fn.__doc__ = inject_docstring_text( @@ -721,9 +721,9 @@ def create_proxy_methods( ) if clslevel: - proxy_fn = classmethod(proxy_fn) - - return proxy_fn + return classmethod(proxy_fn) + else: + return proxy_fn def makeprop(name): attr = target_cls.__dict__.get(name, None) @@ -824,7 +824,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): missing = object() pos_args = [] - kw_args = _collections.OrderedDict() + kw_args: _collections.OrderedDict[str, Any] = _collections.OrderedDict() vargs = None for i, insp in enumerate(to_inspect): try: @@ -855,7 +855,7 @@ def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): ) ] ) - output = [] + output: List[str] = [] output.extend(repr(getattr(obj, arg, None)) for arg in pos_args) @@ -1007,7 +1007,7 @@ def monkeypatch_proxied_specials( if not hasattr(maybe_fn, "__call__"): continue maybe_fn = getattr(maybe_fn, "__func__", maybe_fn) - fn = cast(Callable[..., Any], maybe_fn) + fn = cast(types.FunctionType, maybe_fn) except AttributeError: continue @@ -1024,7 +1024,9 @@ def monkeypatch_proxied_specials( "return %(name)s.%(method)s%(d_args)s" % locals() ) - env = from_instance is not None and {name: from_instance} or {} + env: Dict[str, types.FunctionType] = ( + from_instance is not None and {name: from_instance} or {} + ) exec(py, env) try: env[method].__defaults__ = fn.__defaults__ @@ -1482,6 +1484,7 @@ def dictlike_iteritems(dictlike): def iterator(): for key in dictlike.iterkeys(): + assert getter is not None yield key, getter(key) return iterator() @@ -1989,7 +1992,7 @@ def quoted_token_parser(value): # 0 = outside of quotes # 1 = inside of quotes state = 0 - result = [[]] + result: List[List[str]] = [[]] idx = 0 lv = len(value) while idx < lv: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 291061561d..160eabd85f 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -7,8 +7,6 @@ from typing import Callable # noqa from typing import cast from typing import Dict from typing import ForwardRef -from typing import Generic -from typing import overload from typing import Type from typing import TypeVar from typing import Union @@ -58,35 +56,6 @@ else: from typing import ParamSpec as ParamSpec # noqa F401 -class _TypeToInstance(Generic[_T]): - """describe a variable that moves between a class and an instance of - that class. - - """ - - @overload - def __get__(self, instance: None, owner: Any) -> Type[_T]: - ... - - @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... - - def __get__(self, instance: object, owner: Any) -> Union[Type[_T], _T]: - ... - - @overload - def __set__(self, instance: None, value: Type[_T]) -> None: - ... - - @overload - def __set__(self, instance: object, value: _T) -> None: - ... - - def __set__(self, instance: object, value: Union[Type[_T], _T]) -> None: - ... 
- - def de_stringify_annotation( cls: Type[Any], annotation: Union[str, Type[Any]] ) -> Union[str, Type[Any]]: diff --git a/pyproject.toml b/pyproject.toml index 963b546ed9..b90feae498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,8 +125,8 @@ module = [ ignore_errors = false +# mostly strict without requiring totally untyped things to be +# typed +strict = true allow_untyped_defs = true -check_untyped_defs = false allow_untyped_calls = true - - diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index dbd957703f..8b950026f2 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -2983,7 +2983,7 @@ class HandleErrorTest(fixtures.TestBase): the_conn.append(connection) with mock.patch( - "sqlalchemy.engine.cursor.BaseCursorResult.__init__", + "sqlalchemy.engine.cursor.CursorResult.__init__", Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")), ): with engine.connect() as conn: @@ -3019,7 +3019,7 @@ class HandleErrorTest(fixtures.TestBase): conn = engine.connect() with mock.patch( - "sqlalchemy.engine.cursor.BaseCursorResult.__init__", + "sqlalchemy.engine.cursor.CursorResult.__init__", Mock(side_effect=tsa.exc.InvalidRequestError("duplicate col")), ): assert_raises( diff --git a/test/sql/test_types.py b/test/sql/test_types.py index da96f6c3a0..acf16565a0 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -141,7 +141,7 @@ class AdaptTest(fixtures.TestBase): def test_uppercase_importable(self, typ): if typ.__name__ == typ.__name__.upper(): assert getattr(sa, typ.__name__) is typ - assert typ.__name__ in types.__all__ + assert typ.__name__ in dir(types) @testing.combinations( ((d.name, d) for d in _all_dialects()), argnames="dialect", id_="ia"
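
Note on the pyproject.toml change above: with "strict = true" in place and only
allow_untyped_defs / allow_untyped_calls left enabled, mypy's check_untyped_defs
is now in effect, so the bodies of unannotated functions get analyzed rather than
skipped.  A minimal sketch of the kind of mistake that starts getting reported --
this snippet is illustrative only and is not part of the SQLAlchemy code base:

    import os


    def get_default_schema():
        # An unannotated ("untyped") def: previously mypy skipped this body
        # entirely; with check_untyped_defs in effect it infers Optional[str]
        # for "schema" and flags the missing None guard on the next line,
        # roughly: Item "None" of "Optional[str]" has no attribute "upper"
        schema = os.environ.get("DEFAULT_SCHEMA")
        return schema.upper()

This is the same class of fix the diff itself makes, e.g. the
"row = self.cursor.fetchone()" / "if row is not None:" guard added in
lib/sqlalchemy/engine/default.py.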
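
The removed BufferedRowResultProxy and FullyBufferedResultProxy classes state in
their (also removed) docstrings that since 1.4 the behavior is supplied by cursor
fetch strategy objects (BufferedRowCursorFetchStrategy /
FullyBufferedCursorFetchStrategy).  For code that imported those classes, the
public route to buffered, streaming row fetching is the stream_results execution
option; a sketch, with the database URL and table name as placeholders:

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    with engine.connect() as conn:
        # stream_results=True selects the buffered-row (server-side cursor)
        # fetch strategy that BufferedRowResultProxy used to represent.
        result = conn.execution_options(stream_results=True).execute(
            text("SELECT id, data FROM large_table")
        )
        # partitions() yields rows in batches as they are buffered.
        for batch in result.partitions(1000):
            for row in batch:
                pass  # process row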
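
Finally, the diff drops the _TypeToInstance helper and instead types
Dialect.type_compiler as a plain Union, with a TYPE_CHECKING-only declaration on
DefaultDialect: the attribute holds a compiler class at class-definition time and
a compiler instance after __init__.  A simplified, self-contained sketch of that
pattern -- MiniDialect and GenericTypeCompiler here are stand-ins, not the real
classes:

    import typing
    from typing import Type, cast


    class GenericTypeCompiler:
        def __init__(self, dialect: "MiniDialect") -> None:
            self.dialect = dialect


    class MiniDialect:
        # At class level the attribute names a compiler *class*; after
        # __init__ it is a compiler *instance*.  Declaring only the
        # instance-level type under TYPE_CHECKING and casting at the single
        # call site keeps the runtime behavior while giving mypy one
        # concrete type to check.
        if typing.TYPE_CHECKING:
            type_compiler: GenericTypeCompiler
        else:
            type_compiler = GenericTypeCompiler

        def __init__(self) -> None:
            tt_callable = cast(Type[GenericTypeCompiler], self.type_compiler)
            self.type_compiler = tt_callable(self)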