From: Mike Bayer Date: Thu, 17 Feb 2022 18:43:04 +0000 (-0500) Subject: pep-484 for engine X-Git-Tag: rel_2_0_0b1~460^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a4bb502cf95ea3523e4d383c4377e50f402d7d52;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484 for engine All modules in sqlalchemy.engine are strictly typed with the exception of cursor, default, and reflection. cursor and default pass with non-strict typing, reflection is waiting on the multi-reflection refactor. Behavioral changes: * create_connect_args() methods return a tuple of list, dict, rather than a list of list, dict * removed allow_chars parameter from pyodbc connector ._get_server_version_info() method * the parameter list passed to do_executemany is now a list in all cases. previously, this was being run through dialect.execute_sequence_format, which defaults to tuple and was only intended for individual tuple params. * broke up dialect.dbapi into dialect.import_dbapi class method and dialect.dbapi module object. added a deprecation path for legacy dialects. it's not really feasible to type a single attr as a classmethod vs. module type. The "type_compiler" attribute also has this problem with greater ability to work around, left that one for now. * lots of constants changing to be Enum, so that we can type them. for fixed tuple-position constants in cursor.py / compiler.py (which are used to avoid the speed overhead of namedtuple), using Literal[value] which seems to work well * some tightening up in Row regarding __getitem__, which we can do since we are on full 2.0 style result use * altered the set_connection_execution_options and set_engine_execution_options event flows so that the dictionary of options may be mutated within the event hook, where it will then take effect as the actual options used. Previously, changing the dict would be silently ignored which seems counter-intuitive and not very useful. * A lot of DefaultDialect/DefaultExecutionContext methods and attributes, including underscored ones, move to interfaces. This is not fully ideal as it means the Dialect/ExecutionContext interfaces aren't publicly subclassable directly, but their current purpose is more of documentation for dialect authors who should (and certainly are) still be subclassing the DefaultXYZ versions in all cases Overall, Result was the most extremely difficult class hierarchy to type here as this hierarchy passes through largely amorphous "row" datatypes throughout, which can in fact by all kinds of different things, like raw DBAPI rows, or Row objects, or "scalar"/Any, but at the same time these types have meaning so I tried still maintaining some level of semantic markings for these, it highlights how complex Result is now, as it's trying to be extremely efficient and inlined while also being very open-ended and extensible. Change-Id: I98b75c0c09eab5355fc7a33ba41dd9874274f12a --- diff --git a/doc/build/changelog/unreleased_20/eng_ex_opt.rst b/doc/build/changelog/unreleased_20/eng_ex_opt.rst new file mode 100644 index 0000000000..00947f3ded --- /dev/null +++ b/doc/build/changelog/unreleased_20/eng_ex_opt.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: engine, feature + + The :meth:`.ConnectionEvents.set_connection_execution_options` + and :meth:`.ConnectionEvents.set_engine_execution_options` + event hooks now allow the given options dictionary to be modified + in-place, where the new contents will be received as the ultimate + execution options to be acted upon. Previously, in-place modifications to + the dictionary were not supported. diff --git a/doc/build/tutorial/data_insert.rst b/doc/build/tutorial/data_insert.rst index 74b0aff56c..a8b1a49a25 100644 --- a/doc/build/tutorial/data_insert.rst +++ b/doc/build/tutorial/data_insert.rst @@ -127,7 +127,7 @@ illustrate this: ... conn.commit() {opensql}BEGIN (implicit) INSERT INTO user_account (name, fullname) VALUES (?, ?) - [...] (('sandy', 'Sandy Cheeks'), ('patrick', 'Patrick Star')) + [...] [('sandy', 'Sandy Cheeks'), ('patrick', 'Patrick Star')] COMMIT{stop} The execution above features "executemany" form first illustrated at @@ -185,8 +185,8 @@ construct automatically. INSERT INTO address (user_id, email_address) VALUES ((SELECT user_account.id FROM user_account WHERE user_account.name = ?), ?) - [...] (('spongebob', 'spongebob@sqlalchemy.org'), ('sandy', 'sandy@sqlalchemy.org'), - ('sandy', 'sandy@squirrelpower.org')) + [...] [('spongebob', 'spongebob@sqlalchemy.org'), ('sandy', 'sandy@sqlalchemy.org'), + ('sandy', 'sandy@squirrelpower.org')] COMMIT{stop} .. _tutorial_insert_from_select: diff --git a/doc/build/tutorial/data_update.rst b/doc/build/tutorial/data_update.rst index 8813dda988..8e88eb2f75 100644 --- a/doc/build/tutorial/data_update.rst +++ b/doc/build/tutorial/data_update.rst @@ -101,7 +101,7 @@ that literal values would normally go: ... ) {opensql}BEGIN (implicit) UPDATE user_account SET name=? WHERE user_account.name = ? - [...] (('ed', 'jack'), ('mary', 'wendy'), ('jake', 'jim')) + [...] [('ed', 'jack'), ('mary', 'wendy'), ('jake', 'jim')] COMMIT{stop} diff --git a/doc/build/tutorial/dbapi_transactions.rst b/doc/build/tutorial/dbapi_transactions.rst index 16768da2b9..f4d2ad8e07 100644 --- a/doc/build/tutorial/dbapi_transactions.rst +++ b/doc/build/tutorial/dbapi_transactions.rst @@ -115,7 +115,7 @@ where we acquired the :class:`_future.Connection` object: [...] () INSERT INTO some_table (x, y) VALUES (?, ?) - [...] ((1, 1), (2, 4)) + [...] [(1, 1), (2, 4)] COMMIT @@ -149,7 +149,7 @@ may be referred towards as **begin once**: ... ) {opensql}BEGIN (implicit) INSERT INTO some_table (x, y) VALUES (?, ?) - [...] ((6, 8), (9, 10)) + [...] [(6, 8), (9, 10)] COMMIT @@ -374,7 +374,7 @@ be invoked against each parameter set individually: ... conn.commit() {opensql}BEGIN (implicit) INSERT INTO some_table (x, y) VALUES (?, ?) - [...] ((11, 12), (13, 14)) + [...] [(11, 12), (13, 14)] COMMIT @@ -508,7 +508,7 @@ our data: ... session.commit() {opensql}BEGIN (implicit) UPDATE some_table SET y=? WHERE x=? - [...] ((11, 9), (15, 13)) + [...] [(11, 9), (15, 13)] COMMIT{stop} Above, we invoked an UPDATE statement using the bound-parameter, "executemany" diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py index 132a0a4de6..f4fa5b66b7 100644 --- a/lib/sqlalchemy/connectors/__init__.py +++ b/lib/sqlalchemy/connectors/__init__.py @@ -6,5 +6,13 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php -class Connector: - pass +from ..engine.interfaces import Dialect + + +class Connector(Dialect): + """Base class for dialect mixins, for DBAPIs that work + across entirely different database backends. + + Currently the only such mixin is pyodbc. + + """ diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index f7d01ce437..c5f07de077 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -5,12 +5,27 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import re +from types import ModuleType +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union from urllib.parse import unquote_plus from . import Connector +from .. import ExecutionContext +from .. import pool from .. import util +from ..engine import ConnectArgsType +from ..engine import Connection from ..engine import interfaces +from ..engine import URL +from ..sql.type_api import TypeEngine class PyODBCConnector(Connector): @@ -25,18 +40,20 @@ class PyODBCConnector(Connector): # for non-DSN connections, this *may* be used to # hold the desired driver name - pyodbc_driver_name = None + pyodbc_driver_name: Optional[str] = None + + dbapi: ModuleType - def __init__(self, use_setinputsizes=False, **kw): + def __init__(self, use_setinputsizes: bool = False, **kw: Any): super(PyODBCConnector, self).__init__(**kw) if use_setinputsizes: self.bind_typing = interfaces.BindTyping.SETINPUTSIZES @classmethod - def dbapi(cls): + def import_dbapi(cls) -> ModuleType: return __import__("pyodbc") - def create_connect_args(self, url): + def create_connect_args(self, url: URL) -> ConnectArgsType: opts = url.translate_connect_args(username="user") opts.update(url.query) @@ -44,7 +61,9 @@ class PyODBCConnector(Connector): query = url.query - connect_args = {} + connect_args: Dict[str, Any] = {} + connectors: List[str] + for param in ("ansi", "unicode_results", "autocommit"): if param in keys: connect_args[param] = util.asbool(keys.pop(param)) @@ -53,7 +72,7 @@ class PyODBCConnector(Connector): connectors = [unquote_plus(keys.pop("odbc_connect"))] else: - def check_quote(token): + def check_quote(token: str) -> str: if ";" in str(token): token = "{%s}" % token.replace("}", "}}") return token @@ -115,9 +134,14 @@ class PyODBCConnector(Connector): connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()]) - return [[";".join(connectors)], connect_args] + return ((";".join(connectors),), connect_args) - def is_disconnect(self, e, connection, cursor): + def is_disconnect( + self, + e: Exception, + connection: Optional[pool.PoolProxiedConnection], + cursor: Optional[interfaces.DBAPICursor], + ) -> bool: if isinstance(e, self.dbapi.ProgrammingError): return "The cursor's connection has been closed." in str( e @@ -125,36 +149,44 @@ class PyODBCConnector(Connector): else: return False - def _dbapi_version(self): + def _dbapi_version(self) -> interfaces.VersionInfoType: if not self.dbapi: return () return self._parse_dbapi_version(self.dbapi.version) - def _parse_dbapi_version(self, vers): + def _parse_dbapi_version(self, vers: str) -> interfaces.VersionInfoType: m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers) if not m: return () - vers = tuple([int(x) for x in m.group(1).split(".")]) + vers_tuple: interfaces.VersionInfoType = tuple( + [int(x) for x in m.group(1).split(".")] + ) if m.group(2): - vers += (m.group(2),) - return vers + vers_tuple += (m.group(2),) + return vers_tuple - def _get_server_version_info(self, connection, allow_chars=True): + def _get_server_version_info( + self, connection: Connection + ) -> interfaces.VersionInfoType: # NOTE: this function is not reliable, particularly when # freetds is in use. Implement database-specific server version # queries. - dbapi_con = connection.connection - version = [] + dbapi_con = connection.connection.dbapi_connection + version: Tuple[Union[int, str], ...] = () r = re.compile(r"[.\-]") - for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): # type: ignore[union-attr] # noqa E501 try: - version.append(int(n)) + version += (int(n),) except ValueError: - if allow_chars: - version.append(n) + pass return tuple(version) - def do_set_input_sizes(self, cursor, list_of_tuples, context): + def do_set_input_sizes( + self, + cursor: interfaces.DBAPICursor, + list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]], + context: ExecutionContext, + ) -> None: # the rules for these types seems a little strange, as you can pass # non-tuples as well as tuples, however it seems to assume "0" # for the subsequent values if you don't pass a tuple which fails @@ -174,12 +206,16 @@ class PyODBCConnector(Connector): ] ) - def get_isolation_level_values(self, dbapi_connection): - return super().get_isolation_level_values(dbapi_connection) + [ + def get_isolation_level_values( + self, dbapi_connection: interfaces.DBAPIConnection + ) -> List[str]: + return super().get_isolation_level_values(dbapi_connection) + [ # type: ignore # noqa E501 "AUTOCOMMIT" ] - def set_isolation_level(self, dbapi_connection, level): + def set_isolation_level( + self, dbapi_connection: interfaces.DBAPIConnection, level: str + ) -> None: # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly @@ -188,6 +224,4 @@ class PyODBCConnector(Connector): dbapi_connection.autocommit = True else: dbapi_connection.autocommit = False - super(PyODBCConnector, self).set_isolation_level( - dbapi_connection, level - ) + super().set_isolation_level(dbapi_connection, level) diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx index daf5cc9400..e88c8ec0be 100644 --- a/lib/sqlalchemy/cyextension/resultproxy.pyx +++ b/lib/sqlalchemy/cyextension/resultproxy.pyx @@ -7,8 +7,6 @@ cdef int MD_INDEX = 0 # integer index in cursor.description KEY_INTEGER_ONLY = 0 KEY_OBJECTS_ONLY = 1 -sqlalchemy_engine_row = None - cdef class BaseRow: cdef readonly object _parent cdef readonly tuple _data @@ -53,19 +51,6 @@ cdef class BaseRow: self._keymap = self._parent._keymap self._key_style = state["_key_style"] - def _filter_on_values(self, filters): - global sqlalchemy_engine_row - if sqlalchemy_engine_row is None: - from sqlalchemy.engine.row import Row as sqlalchemy_engine_row - - return sqlalchemy_engine_row( - self._parent, - filters, - self._keymap, - self._key_style, - self._data, - ) - def _values_impl(self): return list(self) @@ -78,18 +63,8 @@ cdef class BaseRow: def __hash__(self): return hash(self._data) - def _get_by_int_impl(self, key): - return self._data[key] - - cpdef _get_by_key_impl(self, key): - # keep two isinstance since it's noticeably faster in the int case - if isinstance(key, int) or isinstance(key, slice): - return self._data[key] - - self._parent._raise_for_nonint(key) - - def __getitem__(self, key): - return self._get_by_key_impl(key) + def __getitem__(self, index): + return self._data[index] cpdef _get_by_key_impl_mapping(self, key): try: diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 20d2b09db7..8d654b72df 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -77,7 +77,7 @@ class MSDialect_pymssql(MSDialect): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): module = __import__("pymssql") # pymmsql < 2.1.1 doesn't have a Binary method. we use string client_ver = tuple(int(x) for x in module.__version__.split(".")) @@ -106,7 +106,7 @@ class MSDialect_pymssql(MSDialect): port = opts.pop("port", None) if port and "host" in opts: opts["host"] = "%s:%s" % (opts["host"], port) - return [[], opts] + return ([], opts) def is_disconnect(self, e, connection, cursor): for msg in ( diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 3af083a6ae..0951f219b3 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -538,7 +538,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # 2008. Before we had the VARCHAR cast above, pyodbc would also # fail on this query. return super(MSDialect_pyodbc, self)._get_server_version_info( - connection, allow_chars=False + connection ) else: version = [] diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index df716346ef..d685b7ea10 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -276,7 +276,7 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): is_async = True @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_aiomysql_dbapi( __import__("aiomysql"), __import__("pymysql") ) diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 915b666bbe..7d5b1bf866 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -288,7 +288,7 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): is_async = True @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) @classmethod diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 5c2de09115..b2ccfc90f2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2434,7 +2434,7 @@ class MySQLDialect(default.DefaultDialect): @classmethod def _is_mariadb_from_url(cls, url): - dbapi = cls.dbapi() + dbapi = cls.import_dbapi() dialect = cls(dbapi=dbapi) cargs, cparams = dialect.create_connect_args(url) diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index 9fd0b4a093..281c509b79 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -53,7 +53,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("cymysql") def _detect_charset(self, connection): diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index fca91204f2..bf2b042513 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -111,7 +111,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("mariadb") def is_disconnect(self, e, connection, cursor): diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index c96b739dc4..a69dac9a5c 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -77,7 +77,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect): colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) @classmethod - def dbapi(cls): + def import_dbapi(cls): from mysql import connector return connector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b4f071de0b..6d66f88b4f 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -159,7 +159,7 @@ class MySQLDialect_mysqldb(MySQLDialect): return False @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("MySQLdb") def on_connect(self): diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index eddb9c9219..9a240da618 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -57,7 +57,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): return False @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("pymysql") def create_connect_args(self, url, _translate_args=None): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index a390099aef..98181051e5 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -963,7 +963,7 @@ class OracleDialect_cx_oracle(OracleDialect): return (0, 0, 0) @classmethod - def dbapi(cls): + def import_dbapi(cls): import cx_Oracle return cx_Oracle diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 4c3c47ba6f..75f6c2704b 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -878,7 +878,7 @@ class PGDialect_asyncpg(PGDialect): return (99, 99, 99) @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) @util.memoized_property diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index c23da93bb6..372b8639e0 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -426,7 +426,7 @@ class PGDialect_pg8000(PGDialect): return (99, 99, 99) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("pg8000") def create_connect_args(self, url): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 3ba535d6cf..33dc65afc5 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -281,7 +281,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): register_hstore(info, connection.connection) @classmethod - def dbapi(cls): + def import_dbapi(cls): import psycopg return psycopg @@ -592,7 +592,7 @@ class PGDialectAsync_psycopg(PGDialect_psycopg): supports_statement_cache = True @classmethod - def dbapi(cls): + def import_dbapi(cls): import psycopg from psycopg.pq import ExecStatus diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index a08c5e5b07..dddce5a629 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -612,7 +612,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): import psycopg2 return psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py index 5a4dcb2e67..0943613a28 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -44,7 +44,7 @@ class PGDialect_psycopg2cffi(PGDialect_psycopg2): ) @classmethod - def dbapi(cls): + def import_dbapi(cls): return __import__("psycopg2cffi") @util.memoized_property diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index e88ab1a0fa..dd0499975b 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -308,7 +308,7 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): execution_ctx_cls = SQLiteExecutionContext_aiosqlite @classmethod - def dbapi(cls): + def import_dbapi(cls): return AsyncAdapt_aiosqlite_dbapi( __import__("aiosqlite"), __import__("sqlite3") ) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 28f7952981..b67eed9749 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -105,7 +105,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") @classmethod - def dbapi(cls): + def import_dbapi(cls): try: import sqlcipher3 as sqlcipher except ImportError: diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 8476e68342..2aa7149a68 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -465,7 +465,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): driver = "pysqlite" @classmethod - def dbapi(cls): + def import_dbapi(cls): from sqlite3 import dbapi2 as sqlite return sqlite diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index c6bc4b6aa6..32f3f2eccd 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -35,6 +35,7 @@ from .cursor import ResultProxy as ResultProxy from .interfaces import AdaptedConnection as AdaptedConnection from .interfaces import BindTyping as BindTyping from .interfaces import Compiled as Compiled +from .interfaces import ConnectArgsType as ConnectArgsType from .interfaces import CreateEnginePlugin as CreateEnginePlugin from .interfaces import Dialect as Dialect from .interfaces import ExceptionContext as ExceptionContext diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py index e3024471a2..27cb9e9395 100644 --- a/lib/sqlalchemy/engine/_py_processors.py +++ b/lib/sqlalchemy/engine/_py_processors.py @@ -16,16 +16,30 @@ They all share one common characteristic: None is passed through unchanged. from __future__ import annotations import datetime +from decimal import Decimal import re +import typing +from typing import Any +from typing import Callable +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union + +_DT = TypeVar( + "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date] +) -def str_to_datetime_processor_factory(regexp, type_): +def str_to_datetime_processor_factory( + regexp: typing.Pattern[str], type_: Callable[..., _DT] +) -> Callable[[Optional[str]], Optional[_DT]]: rmatch = regexp.match # Even on python2.6 datetime.strptime is both slower than this code # and it does not support microseconds. has_named_groups = bool(regexp.groupindex) - def process(value): + def process(value: Optional[str]) -> Optional[_DT]: if value is None: return None else: @@ -59,10 +73,12 @@ def str_to_datetime_processor_factory(regexp, type_): return process -def to_decimal_processor_factory(target_class, scale): +def to_decimal_processor_factory( + target_class: Type[Decimal], scale: int +) -> Callable[[Optional[float]], Optional[Decimal]]: fstring = "%%.%df" % scale - def process(value): + def process(value: Optional[float]) -> Optional[Decimal]: if value is None: return None else: @@ -71,21 +87,21 @@ def to_decimal_processor_factory(target_class, scale): return process -def to_float(value): +def to_float(value: Optional[Union[int, float]]) -> Optional[float]: if value is None: return None else: return float(value) -def to_str(value): +def to_str(value: Optional[Any]) -> Optional[str]: if value is None: return None else: return str(value) -def int_to_boolean(value): +def int_to_boolean(value: Optional[int]) -> Optional[bool]: if value is None: return None else: diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py index a6d5b79d59..7cbac552fd 100644 --- a/lib/sqlalchemy/engine/_py_row.py +++ b/lib/sqlalchemy/engine/_py_row.py @@ -1,26 +1,59 @@ from __future__ import annotations +import enum import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +if typing.TYPE_CHECKING: + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _RawRowType + from .result import _TupleGetterType + from .result import ResultMetaData MD_INDEX = 0 # integer index in cursor.description -KEY_INTEGER_ONLY = 0 -"""__getitem__ only allows integer values and slices, raises TypeError - otherwise""" -KEY_OBJECTS_ONLY = 1 -"""__getitem__ only allows string/object values, raises TypeError otherwise""" +class _KeyStyle(enum.Enum): + KEY_INTEGER_ONLY = 0 + """__getitem__ only allows integer values and slices, raises TypeError + otherwise""" -sqlalchemy_engine_row = None + KEY_OBJECTS_ONLY = 1 + """__getitem__ only allows string/object values, raises TypeError + otherwise""" + + +KEY_INTEGER_ONLY, KEY_OBJECTS_ONLY = list(_KeyStyle) class BaseRow: - Row = None __slots__ = ("_parent", "_data", "_keymap", "_key_style") - def __init__(self, parent, processors, keymap, key_style, data): + _parent: ResultMetaData + _data: _RawRowType + _keymap: _KeyMapType + _key_style: _KeyStyle + + def __init__( + self, + parent: ResultMetaData, + processors: Optional[_ProcessorsType], + keymap: _KeyMapType, + key_style: _KeyStyle, + data: _RawRowType, + ): """Row objects are constructed by CursorResult objects.""" - object.__setattr__(self, "_parent", parent) if processors: @@ -41,68 +74,45 @@ class BaseRow: object.__setattr__(self, "_key_style", key_style) - def __reduce__(self): + def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]: return ( rowproxy_reconstructor, (self.__class__, self.__getstate__()), ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "_parent": self._parent, "_data": self._data, "_key_style": self._key_style, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: parent = state["_parent"] object.__setattr__(self, "_parent", parent) object.__setattr__(self, "_data", state["_data"]) object.__setattr__(self, "_keymap", parent._keymap) object.__setattr__(self, "_key_style", state["_key_style"]) - def _filter_on_values(self, filters): - global sqlalchemy_engine_row - if sqlalchemy_engine_row is None: - from sqlalchemy.engine.row import Row as sqlalchemy_engine_row - - return sqlalchemy_engine_row( - self._parent, - filters, - self._keymap, - self._key_style, - self._data, - ) - - def _values_impl(self): + def _values_impl(self) -> List[Any]: return list(self) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self._data) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def __hash__(self): + def __hash__(self) -> int: return hash(self._data) - def _get_by_int_impl(self, key): + def _get_by_int_impl(self, key: Union[int, slice]) -> Any: return self._data[key] - def _get_by_key_impl(self, key): - # keep two isinstance since it's noticeably faster in the int case - if isinstance(key, int) or isinstance(key, slice): - return self._data[key] - - self._parent._raise_for_nonint(key) - - # The original 1.4 plan was that Row would not allow row["str"] - # access, however as the C extensions were inadvertently allowing - # this coupled with the fact that orm Session sets future=True, - # this allows a softer upgrade path. see #6218 - __getitem__ = _get_by_key_impl + if not typing.TYPE_CHECKING: + __getitem__ = _get_by_int_impl - def _get_by_key_impl_mapping(self, key): + def _get_by_key_impl_mapping(self, key: _KeyType) -> Any: try: rec = self._keymap[key] except KeyError as ke: @@ -116,7 +126,7 @@ class BaseRow: return self._data[mdindex] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: try: return self._get_by_key_impl_mapping(name) except KeyError as e: @@ -125,13 +135,15 @@ class BaseRow: # This reconstructor is necessary so that pickles with the Cy extension or # without use the same Binary format. -def rowproxy_reconstructor(cls, state): +def rowproxy_reconstructor( + cls: Type[BaseRow], state: Dict[str, Any] +) -> BaseRow: obj = cls.__new__(cls) obj.__setstate__(state) return obj -def tuplegetter(*indexes): +def tuplegetter(*indexes: int) -> _TupleGetterType: it = operator.itemgetter(*indexes) if len(indexes) > 1: diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py index ff03a47613..538c075a2b 100644 --- a/lib/sqlalchemy/engine/_py_util.py +++ b/lib/sqlalchemy/engine/_py_util.py @@ -1,21 +1,32 @@ from __future__ import annotations -from collections import abc as collections_abc +import typing +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Tuple from .. import exc -_no_tuple = () +if typing.TYPE_CHECKING: + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams -def _distill_params_20(params): +_no_tuple: Tuple[Any, ...] = () + + +def _distill_params_20( + params: Optional[_CoreAnyExecuteParams], +) -> _CoreMultiExecuteParams: if params is None: return _no_tuple # Assume list is more likely than tuple elif isinstance(params, list) or isinstance(params, tuple): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance( - params[0], (tuple, collections_abc.Mapping) - ): + if params and not isinstance(params[0], (tuple, Mapping)): raise exc.ArgumentError( "List argument must consist only of tuples or dictionaries" ) @@ -25,21 +36,21 @@ def _distill_params_20(params): # only do immutabledict or abc.__instancecheck__ for Mapping after # we've checked for plain dictionaries and would otherwise raise params, - collections_abc.Mapping, + Mapping, ): return [params] else: raise exc.ArgumentError("mapping or list expected for parameters") -def _distill_raw_params(params): +def _distill_raw_params( + params: Optional[_DBAPIAnyExecuteParams], +) -> _DBAPIMultiExecuteParams: if params is None: return _no_tuple elif isinstance(params, list): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance( - params[0], (tuple, collections_abc.Mapping) - ): + if params and not isinstance(params[0], (tuple, Mapping)): raise exc.ArgumentError( "List argument must consist only of tuples or dictionaries" ) @@ -49,8 +60,9 @@ def _distill_raw_params(params): # only do abc.__instancecheck__ for Mapping after we've checked # for plain dictionaries and would otherwise raise params, - collections_abc.Mapping, + Mapping, ): - return [params] + # cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params]) + return [params] # type: ignore else: raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 8c99f63090..5ce531338e 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -10,13 +10,24 @@ import contextlib import sys import typing from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List from typing import Mapping +from typing import MutableMapping +from typing import NoReturn from typing import Optional +from typing import Tuple +from typing import Type from typing import Union from .interfaces import BindTyping from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPICursor from .interfaces import ExceptionContext +from .interfaces import ExecutionContext from .util import _distill_params_20 from .util import _distill_raw_params from .util import TransactionalContext @@ -26,22 +37,48 @@ from .. import log from .. import util from ..sql import compiler from ..sql import util as sql_util -from ..sql._typing import _ExecuteOptions -from ..sql._typing import _ExecuteParams + +_CompiledCacheType = MutableMapping[Any, Any] if typing.TYPE_CHECKING: + from . import Result + from . import ScalarResult + from .interfaces import _AnyExecuteParams + from .interfaces import _AnyMultiExecuteParams + from .interfaces import _AnySingleExecuteParams + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import _ExecuteOptionsParameter + from .interfaces import _SchemaTranslateMapType from .interfaces import Dialect from .reflection import Inspector # noqa from .url import URL + from ..event import dispatcher + from ..log import _EchoFlagType + from ..pool import _ConnectionFairy from ..pool import Pool from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql.base import SchemaVisitor + from ..sql.compiler import Compiled + from ..sql.ddl import DDLElement + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.functions import FunctionElement + from ..sql.schema import ColumnDefault + from ..sql.schema import HasSchemaAttr """Defines :class:`_engine.Connection` and :class:`_engine.Engine`. """ -_EMPTY_EXECUTION_OPTS = util.immutabledict() -NO_OPTIONS = util.immutabledict() +_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.immutabledict() +NO_OPTIONS: Mapping[str, Any] = util.immutabledict() class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @@ -69,23 +106,32 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ + dispatch: dispatcher[ConnectionEventsTarget] + _sqla_logger_namespace = "sqlalchemy.engine.Connection" # used by sqlalchemy.engine.util.TransactionalContext - _trans_context_manager = None + _trans_context_manager: Optional[TransactionalContext] = None # legacy as of 2.0, should be eventually deprecated and # removed. was used in the "pre_ping" recipe that's been in the docs # a long time should_close_with_result = False + _dbapi_connection: Optional[PoolProxiedConnection] + + _execution_options: _ExecuteOptions + + _transaction: Optional[RootTransaction] + _nested_transaction: Optional[NestedTransaction] + def __init__( self, - engine, - connection=None, - _has_events=None, - _allow_revalidate=True, - _allow_autobegin=True, + engine: Engine, + connection: Optional[PoolProxiedConnection] = None, + _has_events: Optional[bool] = None, + _allow_revalidate: bool = True, + _allow_autobegin: bool = True, ): """Construct a new Connection.""" self.engine = engine @@ -125,14 +171,14 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self.dispatch.engine_connect(self) @util.memoized_property - def _message_formatter(self): + def _message_formatter(self) -> Any: if "logging_token" in self._execution_options: token = self._execution_options["logging_token"] return lambda msg: "[%s] %s" % (token, msg) else: return None - def _log_info(self, message, *arg, **kw): + def _log_info(self, message: str, *arg: Any, **kw: Any) -> None: fmt = self._message_formatter if fmt: @@ -143,7 +189,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self.engine.logger.info(message, *arg, **kw) - def _log_debug(self, message, *arg, **kw): + def _log_debug(self, message: str, *arg: Any, **kw: Any) -> None: fmt = self._message_formatter if fmt: @@ -155,19 +201,19 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self.engine.logger.debug(message, *arg, **kw) @property - def _schema_translate_map(self): + def _schema_translate_map(self) -> Optional[_SchemaTranslateMapType]: return self._execution_options.get("schema_translate_map", None) - def schema_for_object(self, obj): + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: """Return the schema name for the given schema item taking into account current schema translate map. """ name = obj.schema - schema_translate_map = self._execution_options.get( - "schema_translate_map", None - ) + schema_translate_map: Optional[ + Mapping[Optional[str], str] + ] = self._execution_options.get("schema_translate_map", None) if ( schema_translate_map @@ -178,13 +224,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): else: return name - def __enter__(self): + def __enter__(self) -> Connection: return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: self.close() - def execution_options(self, **opt): + def execution_options(self, **opt: Any) -> Connection: r"""Set non-SQL options for the connection which take effect during execution. @@ -346,13 +392,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ORM-specific execution options """ # noqa - self._execution_options = self._execution_options.union(opt) if self._has_events or self.engine._has_events: self.dispatch.set_connection_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) self.dialect.set_connection_execution_options(self, opt) return self - def get_execution_options(self): + def get_execution_options(self) -> _ExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded:: 1.3 @@ -364,14 +410,27 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return self._execution_options @property - def closed(self): + def _still_open_and_dbapi_connection_is_valid(self) -> bool: + pool_proxied_connection = self._dbapi_connection + return ( + pool_proxied_connection is not None + and pool_proxied_connection.is_valid + ) + + @property + def closed(self) -> bool: """Return True if this connection is closed.""" return self._dbapi_connection is None and not self.__can_reconnect @property - def invalidated(self): - """Return True if this connection was invalidated.""" + def invalidated(self) -> bool: + """Return True if this connection was invalidated. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + """ # prior to 1.4, "invalid" was stored as a state independent of # "closed", meaning an invalidated connection could be "closed", @@ -382,10 +441,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): # "closed" does not need to be "invalid". So the state is now # represented by the two facts alone. - return self._dbapi_connection is None and not self.closed + pool_proxied_connection = self._dbapi_connection + return pool_proxied_connection is None and self.__can_reconnect @property - def connection(self) -> "PoolProxiedConnection": + def connection(self) -> PoolProxiedConnection: """The underlying DB-API connection managed by this Connection. This is a SQLAlchemy connection-pool proxied connection @@ -410,7 +470,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): else: return self._dbapi_connection - def get_isolation_level(self): + def get_isolation_level(self) -> str: """Return the current isolation level assigned to this :class:`_engine.Connection`. @@ -442,15 +502,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): - set per :class:`_engine.Connection` isolation level """ + dbapi_connection = self.connection.dbapi_connection + assert dbapi_connection is not None try: - return self.dialect.get_isolation_level( - self.connection.dbapi_connection - ) + return self.dialect.get_isolation_level(dbapi_connection) except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) @property - def default_isolation_level(self): + def default_isolation_level(self) -> str: """The default isolation level assigned to this :class:`_engine.Connection`. @@ -482,7 +542,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ return self.dialect.default_isolation_level - def _invalid_transaction(self): + def _invalid_transaction(self) -> NoReturn: raise exc.PendingRollbackError( "Can't reconnect until invalid %stransaction is rolled " "back. Please rollback() fully before proceeding" @@ -490,7 +550,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): code="8s2b", ) - def _revalidate_connection(self): + def _revalidate_connection(self) -> PoolProxiedConnection: if self.__can_reconnect and self.invalidated: if self._transaction is not None: self._invalid_transaction() @@ -499,13 +559,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): raise exc.ResourceClosedError("This Connection is closed") @property - def _still_open_and_dbapi_connection_is_valid(self): - return self._dbapi_connection is not None and getattr( - self._dbapi_connection, "is_valid", False - ) - - @property - def info(self): + def info(self) -> Dict[str, Any]: """Info dictionary associated with the underlying DBAPI connection referred to by this :class:`_engine.Connection`, allowing user-defined data to be associated with the connection. @@ -518,7 +572,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return self.connection.info - def invalidate(self, exception=None): + def invalidate(self, exception: Optional[BaseException] = None) -> None: """Invalidate the underlying DBAPI connection associated with this :class:`_engine.Connection`. @@ -567,14 +621,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self.invalidated: return + # MARKMARK if self.closed: raise exc.ResourceClosedError("This Connection is closed") if self._still_open_and_dbapi_connection_is_valid: - self._dbapi_connection.invalidate(exception) + pool_proxied_connection = self._dbapi_connection + assert pool_proxied_connection is not None + pool_proxied_connection.invalidate(exception) + self._dbapi_connection = None - def detach(self): + def detach(self) -> None: """Detach the underlying DB-API connection from its connection pool. E.g.:: @@ -600,13 +658,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ - self._dbapi_connection.detach() + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") - def _autobegin(self): - if self._allow_autobegin: + pool_proxied_connection = self._dbapi_connection + if pool_proxied_connection is None: + raise exc.InvalidRequestError( + "Can't detach an invalidated Connection" + ) + pool_proxied_connection.detach() + + def _autobegin(self) -> None: + if self._allow_autobegin and not self.__in_begin: self.begin() - def begin(self): + def begin(self) -> RootTransaction: """Begin a transaction prior to autobegin occurring. E.g.:: @@ -671,14 +737,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): :class:`_engine.Engine` """ - if self.__in_begin: - # for dialects that emit SQL within the process of - # dialect.do_begin() or dialect.do_begin_twophase(), this - # flag prevents "autobegin" from being emitted within that - # process, while allowing self._transaction to remain at None - # until it's complete. - return - elif self._transaction is None: + if self._transaction is None: self._transaction = RootTransaction(self) return self._transaction else: @@ -689,7 +748,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "is called first." ) - def begin_nested(self): + def begin_nested(self) -> NestedTransaction: """Begin a nested transaction (i.e. SAVEPOINT) and return a transaction handle that controls the scope of the SAVEPOINT. @@ -765,7 +824,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return NestedTransaction(self) - def begin_twophase(self, xid=None): + def begin_twophase(self, xid: Optional[Any] = None) -> TwoPhaseTransaction: """Begin a two-phase or XA transaction and return a transaction handle. @@ -794,7 +853,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): xid = self.engine.dialect.create_xid() return TwoPhaseTransaction(self, xid) - def commit(self): + def commit(self) -> None: """Commit the transaction that is currently in progress. This method commits the current transaction if one has been started. @@ -819,7 +878,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._transaction: self._transaction.commit() - def rollback(self): + def rollback(self) -> None: """Roll back the transaction that is currently in progress. This method rolls back the current transaction if one has been started. @@ -845,33 +904,33 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._transaction: self._transaction.rollback() - def recover_twophase(self): + def recover_twophase(self) -> List[Any]: return self.engine.dialect.do_recover_twophase(self) - def rollback_prepared(self, xid, recover=False): + def rollback_prepared(self, xid: Any, recover: bool = False) -> None: self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) - def commit_prepared(self, xid, recover=False): + def commit_prepared(self, xid: Any, recover: bool = False) -> None: self.engine.dialect.do_commit_twophase(self, xid, recover=recover) - def in_transaction(self): + def in_transaction(self) -> bool: """Return True if a transaction is in progress.""" return self._transaction is not None and self._transaction.is_active - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: """Return True if a transaction is in progress.""" return ( self._nested_transaction is not None and self._nested_transaction.is_active ) - def _is_autocommit(self): - return ( + def _is_autocommit_isolation(self) -> bool: + return bool( self._execution_options.get("isolation_level", None) == "AUTOCOMMIT" ) - def get_transaction(self): + def get_transaction(self) -> Optional[RootTransaction]: """Return the current root transaction in progress, if any. .. versionadded:: 1.4 @@ -880,7 +939,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return self._transaction - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[NestedTransaction]: """Return the current nested transaction in progress, if any. .. versionadded:: 1.4 @@ -888,7 +947,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ return self._nested_transaction - def _begin_impl(self, transaction): + def _begin_impl(self, transaction: RootTransaction) -> None: if self._echo: self._log_info("BEGIN (implicit)") @@ -904,13 +963,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): finally: self.__in_begin = False - def _rollback_impl(self): + def _rollback_impl(self) -> None: if self._has_events or self.engine._has_events: self.dispatch.rollback(self) if self._still_open_and_dbapi_connection_is_valid: if self._echo: - if self._is_autocommit(): + if self._is_autocommit_isolation(): self._log_info( "ROLLBACK using DBAPI connection.rollback(), " "DBAPI should ignore due to autocommit mode" @@ -922,13 +981,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _commit_impl(self): + def _commit_impl(self) -> None: if self._has_events or self.engine._has_events: self.dispatch.commit(self) if self._echo: - if self._is_autocommit(): + if self._is_autocommit_isolation(): self._log_info( "COMMIT using DBAPI connection.commit(), " "DBAPI should ignore due to autocommit mode" @@ -940,58 +999,54 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _savepoint_impl(self, name=None): + def _savepoint_impl(self, name: Optional[str] = None) -> str: if self._has_events or self.engine._has_events: self.dispatch.savepoint(self, name) if name is None: self.__savepoint_seq += 1 name = "sa_savepoint_%s" % self.__savepoint_seq - if self._still_open_and_dbapi_connection_is_valid: - self.engine.dialect.do_savepoint(self, name) - return name + self.engine.dialect.do_savepoint(self, name) + return name - def _rollback_to_savepoint_impl(self, name): + def _rollback_to_savepoint_impl(self, name: str) -> None: if self._has_events or self.engine._has_events: self.dispatch.rollback_savepoint(self, name, None) if self._still_open_and_dbapi_connection_is_valid: self.engine.dialect.do_rollback_to_savepoint(self, name) - def _release_savepoint_impl(self, name): + def _release_savepoint_impl(self, name: str) -> None: if self._has_events or self.engine._has_events: self.dispatch.release_savepoint(self, name, None) - if self._still_open_and_dbapi_connection_is_valid: - self.engine.dialect.do_release_savepoint(self, name) + self.engine.dialect.do_release_savepoint(self, name) - def _begin_twophase_impl(self, transaction): + def _begin_twophase_impl(self, transaction: TwoPhaseTransaction) -> None: if self._echo: self._log_info("BEGIN TWOPHASE (implicit)") if self._has_events or self.engine._has_events: self.dispatch.begin_twophase(self, transaction.xid) - if self._still_open_and_dbapi_connection_is_valid: - self.__in_begin = True - try: - self.engine.dialect.do_begin_twophase(self, transaction.xid) - except BaseException as e: - self._handle_dbapi_exception(e, None, None, None, None) - finally: - self.__in_begin = False + self.__in_begin = True + try: + self.engine.dialect.do_begin_twophase(self, transaction.xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + self.__in_begin = False - def _prepare_twophase_impl(self, xid): + def _prepare_twophase_impl(self, xid: Any) -> None: if self._has_events or self.engine._has_events: self.dispatch.prepare_twophase(self, xid) - if self._still_open_and_dbapi_connection_is_valid: - assert isinstance(self._transaction, TwoPhaseTransaction) - try: - self.engine.dialect.do_prepare_twophase(self, xid) - except BaseException as e: - self._handle_dbapi_exception(e, None, None, None, None) + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_prepare_twophase(self, xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) - def _rollback_twophase_impl(self, xid, is_prepared): + def _rollback_twophase_impl(self, xid: Any, is_prepared: bool) -> None: if self._has_events or self.engine._has_events: self.dispatch.rollback_twophase(self, xid, is_prepared) @@ -1004,18 +1059,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _commit_twophase_impl(self, xid, is_prepared): + def _commit_twophase_impl(self, xid: Any, is_prepared: bool) -> None: if self._has_events or self.engine._has_events: self.dispatch.commit_twophase(self, xid, is_prepared) - if self._still_open_and_dbapi_connection_is_valid: - assert isinstance(self._transaction, TwoPhaseTransaction) - try: - self.engine.dialect.do_commit_twophase(self, xid, is_prepared) - except BaseException as e: - self._handle_dbapi_exception(e, None, None, None, None) + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) - def close(self): + def close(self) -> None: """Close this :class:`_engine.Connection`. This results in a release of the underlying database @@ -1050,7 +1104,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): # as we just closed the transaction, close the connection # pool connection without doing an additional reset if skip_reset: - conn._close_no_reset() + cast("_ConnectionFairy", conn)._close_no_reset() else: conn.close() @@ -1061,7 +1115,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._dbapi_connection = None self.__can_reconnect = False - def scalar(self, statement, parameters=None, execution_options=None): + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. This method is shorthand for invoking the @@ -1074,7 +1133,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ return self.execute(statement, parameters, execution_options).scalar() - def scalars(self, statement, parameters=None, execution_options=None): + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult: """Executes and returns a scalar result set, which yields scalar values from the first column of each row. @@ -1093,10 +1157,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def execute( self, - statement, - parameters: Optional[_ExecuteParams] = None, - execution_options: Optional[_ExecuteOptions] = None, - ): + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Result: r"""Executes a SQL statement construct and returns a :class:`_engine.Result`. @@ -1140,7 +1204,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options or NO_OPTIONS, ) - def _execute_function(self, func, distilled_parameters, execution_options): + def _execute_function( + self, + func: FunctionElement[Any], + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Result: """Execute a sql.FunctionElement object.""" return self._execute_clauseelement( @@ -1148,14 +1217,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) def _execute_default( - self, default, distilled_parameters, execution_options - ): + self, + default: ColumnDefault, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Any: """Execute a schema.ColumnDefault object.""" execution_options = self._execution_options.merge_with( execution_options ) + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreAnyExecuteParams] + # note for event handlers, the "distilled parameters" which is always # a list of dicts is broken out into separate "multiparams" and # "params" collections, which allows the handler to distinguish @@ -1169,6 +1244,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) = self._invoke_before_exec_event( default, distilled_parameters, execution_options ) + else: + event_multiparams = event_params = None try: conn = self._dbapi_connection @@ -1198,13 +1275,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return ret - def _execute_ddl(self, ddl, distilled_parameters, execution_options): + def _execute_ddl( + self, + ddl: DDLElement, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Result: """Execute a schema.DDL object.""" execution_options = ddl._execution_options.merge_with( self._execution_options, execution_options ) + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreSingleExecuteParams] + if self._has_events or self.engine._has_events: ( ddl, @@ -1214,6 +1299,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) = self._invoke_before_exec_event( ddl, distilled_parameters, execution_options ) + else: + event_multiparams = event_params = None exec_opts = self._execution_options.merge_with(execution_options) schema_translate_map = exec_opts.get("schema_translate_map", None) @@ -1243,8 +1330,19 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return ret def _invoke_before_exec_event( - self, elem, distilled_params, execution_options - ): + self, + elem: Any, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Tuple[ + Any, + _CoreMultiExecuteParams, + _CoreMultiExecuteParams, + _CoreSingleExecuteParams, + ]: + + event_multiparams: _CoreMultiExecuteParams + event_params: _CoreSingleExecuteParams if len(distilled_params) == 1: event_multiparams, event_params = [], distilled_params[0] @@ -1275,8 +1373,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return elem, distilled_params, event_multiparams, event_params def _execute_clauseelement( - self, elem, distilled_parameters, execution_options - ): + self, + elem: Executable, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Result: """Execute a sql.ClauseElement object.""" execution_options = elem._execution_options.merge_with( @@ -1309,7 +1410,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "schema_translate_map", None ) - compiled_cache = execution_options.get( + compiled_cache: _CompiledCacheType = execution_options.get( "compiled_cache", self.engine._compiled_cache ) @@ -1346,10 +1447,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _execute_compiled( self, - compiled, - distilled_parameters, - execution_options=_EMPTY_EXECUTION_OPTS, - ): + compiled: Compiled, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, + ) -> Result: """Execute a sql.Compiled object. TODO: why do we have this? likely deprecate or remove @@ -1395,8 +1496,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): return ret def exec_driver_sql( - self, statement, parameters=None, execution_options=None - ): + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptions] = None, + ) -> Result: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1456,7 +1560,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect, dialect.execution_ctx_cls._init_statement, statement, - distilled_parameters, + None, execution_options, statement, distilled_parameters, @@ -1466,14 +1570,14 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _execute_context( self, - dialect, - constructor, - statement, - parameters, - execution_options, - *args, - **kw, - ): + dialect: Dialect, + constructor: Callable[..., ExecutionContext], + statement: Union[str, Compiled], + parameters: Optional[_AnyMultiExecuteParams], + execution_options: _ExecuteOptions, + *args: Any, + **kw: Any, + ) -> Result: """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" @@ -1491,7 +1595,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self._handle_dbapi_exception( e, str(statement), parameters, None, None ) - return # not reached if ( self._transaction @@ -1514,29 +1617,33 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if dialect.bind_typing is BindTyping.SETINPUTSIZES: context._set_input_sizes() - cursor, statement, parameters = ( + cursor, str_statement, parameters = ( context.cursor, context.statement, context.parameters, ) + effective_parameters: Optional[_AnyExecuteParams] + if not context.executemany: - parameters = parameters[0] + effective_parameters = parameters[0] + else: + effective_parameters = parameters if self._has_events or self.engine._has_events: for fn in self.dispatch.before_cursor_execute: - statement, parameters = fn( + str_statement, effective_parameters = fn( self, cursor, - statement, - parameters, + str_statement, + effective_parameters, context, context.executemany, ) if self._echo: - self._log_info(statement) + self._log_info(str_statement) stats = context._get_cache_stats() @@ -1545,7 +1652,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "[%s] %r", stats, sql_util._repr_params( - parameters, batches=10, ismulti=context.executemany + effective_parameters, + batches=10, + ismulti=context.executemany, ), ) else: @@ -1554,45 +1663,61 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): % (stats,) ) - evt_handled = False + evt_handled: bool = False try: if context.executemany: + effective_parameters = cast( + "_CoreMultiExecuteParams", effective_parameters + ) if self.dialect._has_events: for fn in self.dialect.dispatch.do_executemany: - if fn(cursor, statement, parameters, context): + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): evt_handled = True break if not evt_handled: self.dialect.do_executemany( - cursor, statement, parameters, context + cursor, str_statement, effective_parameters, context ) - elif not parameters and context.no_parameters: + elif not effective_parameters and context.no_parameters: if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute_no_params: - if fn(cursor, statement, context): + if fn(cursor, str_statement, context): evt_handled = True break if not evt_handled: self.dialect.do_execute_no_params( - cursor, statement, context + cursor, str_statement, context ) else: + effective_parameters = cast( + "_CoreSingleExecuteParams", effective_parameters + ) if self.dialect._has_events: for fn in self.dialect.dispatch.do_execute: - if fn(cursor, statement, parameters, context): + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): evt_handled = True break if not evt_handled: self.dialect.do_execute( - cursor, statement, parameters, context + cursor, str_statement, effective_parameters, context ) if self._has_events or self.engine._has_events: self.dispatch.after_cursor_execute( self, cursor, - statement, - parameters, + str_statement, + effective_parameters, context, context.executemany, ) @@ -1603,12 +1728,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): except BaseException as e: self._handle_dbapi_exception( - e, statement, parameters, cursor, context + e, str_statement, effective_parameters, cursor, context ) return result - def _cursor_execute(self, cursor, statement, parameters, context=None): + def _cursor_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: """Execute a statement + params on the given cursor. Adds appropriate logging and exception handling. @@ -1648,7 +1779,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, cursor, statement, parameters, context, False ) - def _safe_close_cursor(self, cursor): + def _safe_close_cursor(self, cursor: DBAPICursor) -> None: """Close the given cursor, catching exceptions and turning into log warnings. @@ -1665,8 +1796,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): _is_disconnect = False def _handle_dbapi_exception( - self, e, statement, parameters, cursor, context - ): + self, + e: BaseException, + statement: Optional[str], + parameters: Optional[_AnyExecuteParams], + cursor: Optional[DBAPICursor], + context: Optional[ExecutionContext], + ) -> NoReturn: exc_info = sys.exc_info() is_exit_exception = util.is_exit_exception(e) @@ -1708,7 +1844,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): sqlalchemy_exception = exc.DBAPIError.instance( statement, parameters, - e, + cast(Exception, e), self.dialect.dbapi.Error, hide_parameters=self.engine.hide_parameters, connection_invalidated=self._is_disconnect, @@ -1784,8 +1920,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if newraise: raise newraise.with_traceback(exc_info[2]) from e elif should_wrap: + assert sqlalchemy_exception is not None raise sqlalchemy_exception.with_traceback(exc_info[2]) from e else: + assert exc_info[1] is not None raise exc_info[1].with_traceback(exc_info[2]) finally: del self._reentrant_error @@ -1793,15 +1931,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): del self._is_disconnect if not self.invalidated: dbapi_conn_wrapper = self._dbapi_connection + assert dbapi_conn_wrapper is not None if invalidate_pool_on_disconnect: self.engine.pool._invalidate(dbapi_conn_wrapper, e) self.invalidate(e) @classmethod - def _handle_dbapi_exception_noconnection(cls, e, dialect, engine): + def _handle_dbapi_exception_noconnection( + cls, e: BaseException, dialect: Dialect, engine: Engine + ) -> NoReturn: exc_info = sys.exc_info() - is_disconnect = dialect.is_disconnect(e, None, None) + is_disconnect = isinstance( + e, dialect.dbapi.Error + ) and dialect.is_disconnect(e, None, None) should_wrap = isinstance(e, dialect.dbapi.Error) @@ -1809,7 +1952,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): sqlalchemy_exception = exc.DBAPIError.instance( None, None, - e, + cast(Exception, e), dialect.dbapi.Error, hide_parameters=engine.hide_parameters, connection_invalidated=is_disconnect, @@ -1852,11 +1995,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if newraise: raise newraise.with_traceback(exc_info[2]) from e elif should_wrap: + assert sqlalchemy_exception is not None raise sqlalchemy_exception.with_traceback(exc_info[2]) from e else: + assert exc_info[1] is not None raise exc_info[1].with_traceback(exc_info[2]) - def _run_ddl_visitor(self, visitorcallable, element, **kwargs): + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: DDLElement, + **kwargs: Any, + ) -> None: """run a DDL visitor. This method is only here so that the MockConnection can change the @@ -1871,16 +2021,16 @@ class ExceptionContextImpl(ExceptionContext): def __init__( self, - exception, - sqlalchemy_exception, - engine, - connection, - cursor, - statement, - parameters, - context, - is_disconnect, - invalidate_pool_on_disconnect, + exception: BaseException, + sqlalchemy_exception: Optional[exc.StatementError], + engine: Optional[Engine], + connection: Optional[Connection], + cursor: Optional[DBAPICursor], + statement: Optional[str], + parameters: Optional[_DBAPIAnyExecuteParams], + context: Optional[ExecutionContext], + is_disconnect: bool, + invalidate_pool_on_disconnect: bool, ): self.engine = engine self.connection = connection @@ -1932,33 +2082,35 @@ class Transaction(TransactionalContext): __slots__ = () - _is_root = False + _is_root: bool = False + is_active: bool + connection: Connection - def __init__(self, connection): + def __init__(self, connection: Connection): raise NotImplementedError() @property - def _deactivated_from_connection(self): + def _deactivated_from_connection(self) -> bool: """True if this transaction is totally deactivated from the connection and therefore can no longer affect its state. """ raise NotImplementedError() - def _do_close(self): + def _do_close(self) -> None: raise NotImplementedError() - def _do_rollback(self): + def _do_rollback(self) -> None: raise NotImplementedError() - def _do_commit(self): + def _do_commit(self) -> None: raise NotImplementedError() @property - def is_valid(self): + def is_valid(self) -> bool: return self.is_active and not self.connection.invalidated - def close(self): + def close(self) -> None: """Close this :class:`.Transaction`. If this transaction is the base transaction in a begin/commit @@ -1974,7 +2126,7 @@ class Transaction(TransactionalContext): finally: assert not self.is_active - def rollback(self): + def rollback(self) -> None: """Roll back this :class:`.Transaction`. The implementation of this may vary based on the type of transaction in @@ -1996,7 +2148,7 @@ class Transaction(TransactionalContext): finally: assert not self.is_active - def commit(self): + def commit(self) -> None: """Commit this :class:`.Transaction`. The implementation of this may vary based on the type of transaction in @@ -2017,16 +2169,16 @@ class Transaction(TransactionalContext): finally: assert not self.is_active - def _get_subject(self): + def _get_subject(self) -> Connection: return self.connection - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: return self.is_active - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: return not self._deactivated_from_connection - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: # for RootTransaction / NestedTransaction, it's safe to call # rollback() even if the transaction is deactive and no warnings # will be emitted. tested in @@ -2060,7 +2212,7 @@ class RootTransaction(Transaction): __slots__ = ("connection", "is_active") - def __init__(self, connection): + def __init__(self, connection: Connection): assert connection._transaction is None if connection._trans_context_manager: TransactionalContext._trans_ctx_check(connection) @@ -2070,7 +2222,7 @@ class RootTransaction(Transaction): self.is_active = True - def _deactivate_from_connection(self): + def _deactivate_from_connection(self) -> None: if self.is_active: assert self.connection._transaction is self self.is_active = False @@ -2079,19 +2231,19 @@ class RootTransaction(Transaction): util.warn("transaction already deassociated from connection") @property - def _deactivated_from_connection(self): + def _deactivated_from_connection(self) -> bool: return self.connection._transaction is not self - def _connection_begin_impl(self): + def _connection_begin_impl(self) -> None: self.connection._begin_impl(self) - def _connection_rollback_impl(self): + def _connection_rollback_impl(self) -> None: self.connection._rollback_impl() - def _connection_commit_impl(self): + def _connection_commit_impl(self) -> None: self.connection._commit_impl() - def _close_impl(self, try_deactivate=False): + def _close_impl(self, try_deactivate: bool = False) -> None: try: if self.is_active: self._connection_rollback_impl() @@ -2107,13 +2259,13 @@ class RootTransaction(Transaction): assert not self.is_active assert self.connection._transaction is not self - def _do_close(self): + def _do_close(self) -> None: self._close_impl() - def _do_rollback(self): + def _do_rollback(self) -> None: self._close_impl(try_deactivate=True) - def _do_commit(self): + def _do_commit(self) -> None: if self.is_active: assert self.connection._transaction is self @@ -2176,7 +2328,9 @@ class NestedTransaction(Transaction): __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested") - def __init__(self, connection): + _savepoint: str + + def __init__(self, connection: Connection): assert connection._transaction is not None if connection._trans_context_manager: TransactionalContext._trans_ctx_check(connection) @@ -2186,7 +2340,7 @@ class NestedTransaction(Transaction): self._previous_nested = connection._nested_transaction connection._nested_transaction = self - def _deactivate_from_connection(self, warn=True): + def _deactivate_from_connection(self, warn: bool = True) -> None: if self.connection._nested_transaction is self: self.connection._nested_transaction = self._previous_nested elif warn: @@ -2195,10 +2349,10 @@ class NestedTransaction(Transaction): ) @property - def _deactivated_from_connection(self): + def _deactivated_from_connection(self) -> bool: return self.connection._nested_transaction is not self - def _cancel(self): + def _cancel(self) -> None: # called by RootTransaction when the outer transaction is # committed, rolled back, or closed to cancel all savepoints # without any action being taken @@ -2207,9 +2361,15 @@ class NestedTransaction(Transaction): if self._previous_nested: self._previous_nested._cancel() - def _close_impl(self, deactivate_from_connection, warn_already_deactive): + def _close_impl( + self, deactivate_from_connection: bool, warn_already_deactive: bool + ) -> None: try: - if self.is_active and self.connection._transaction.is_active: + if ( + self.is_active + and self.connection._transaction + and self.connection._transaction.is_active + ): self.connection._rollback_to_savepoint_impl(self._savepoint) finally: self.is_active = False @@ -2221,13 +2381,13 @@ class NestedTransaction(Transaction): if deactivate_from_connection: assert self.connection._nested_transaction is not self - def _do_close(self): + def _do_close(self) -> None: self._close_impl(True, False) - def _do_rollback(self): + def _do_rollback(self) -> None: self._close_impl(True, True) - def _do_commit(self): + def _do_commit(self) -> None: if self.is_active: try: self.connection._release_savepoint_impl(self._savepoint) @@ -2261,12 +2421,14 @@ class TwoPhaseTransaction(RootTransaction): __slots__ = ("xid", "_is_prepared") - def __init__(self, connection, xid): + xid: Any + + def __init__(self, connection: Connection, xid: Any): self._is_prepared = False self.xid = xid super(TwoPhaseTransaction, self).__init__(connection) - def prepare(self): + def prepare(self) -> None: """Prepare this :class:`.TwoPhaseTransaction`. After a PREPARE, the transaction can be committed. @@ -2277,13 +2439,13 @@ class TwoPhaseTransaction(RootTransaction): self.connection._prepare_twophase_impl(self.xid) self._is_prepared = True - def _connection_begin_impl(self): + def _connection_begin_impl(self) -> None: self.connection._begin_twophase_impl(self) - def _connection_rollback_impl(self): + def _connection_rollback_impl(self) -> None: self.connection._rollback_twophase_impl(self.xid, self._is_prepared) - def _connection_commit_impl(self): + def _connection_commit_impl(self) -> None: self.connection._commit_twophase_impl(self.xid, self._is_prepared) @@ -2310,17 +2472,23 @@ class Engine( """ - _execution_options = _EMPTY_EXECUTION_OPTS - _has_events = False - _connection_cls = Connection - _sqla_logger_namespace = "sqlalchemy.engine.Engine" - _is_future = False + dispatch: dispatcher[ConnectionEventsTarget] - _schema_translate_map = None + _compiled_cache: Optional[_CompiledCacheType] + + _execution_options: _ExecuteOptions = _EMPTY_EXECUTION_OPTS + _has_events: bool = False + _connection_cls: Type[Connection] = Connection + _sqla_logger_namespace: str = "sqlalchemy.engine.Engine" + _is_future: bool = False + + _schema_translate_map: Optional[_SchemaTranslateMapType] = None + _option_cls: Type[OptionEngine] dialect: Dialect pool: Pool url: URL + hide_parameters: bool def __init__( self, @@ -2328,7 +2496,7 @@ class Engine( dialect: Dialect, url: URL, logging_name: Optional[str] = None, - echo: Union[None, str, bool] = None, + echo: Optional[_EchoFlagType] = None, query_cache_size: int = 500, execution_options: Optional[Mapping[str, Any]] = None, hide_parameters: bool = False, @@ -2350,7 +2518,7 @@ class Engine( if execution_options: self.update_execution_options(**execution_options) - def _lru_size_alert(self, cache): + def _lru_size_alert(self, cache: util.LRUCache[Any, Any]) -> None: if self._should_log_info: self.logger.info( "Compiled cache size pruning from %d items to %d. " @@ -2360,10 +2528,10 @@ class Engine( ) @property - def engine(self): + def engine(self) -> Engine: return self - def clear_compiled_cache(self): + def clear_compiled_cache(self) -> None: """Clear the compiled cache associated with the dialect. This applies **only** to the built-in cache that is established @@ -2377,7 +2545,7 @@ class Engine( if self._compiled_cache: self._compiled_cache.clear() - def update_execution_options(self, **opt): + def update_execution_options(self, **opt: Any) -> None: r"""Update the default execution_options dictionary of this :class:`_engine.Engine`. @@ -2394,11 +2562,11 @@ class Engine( :meth:`_engine.Engine.execution_options` """ - self._execution_options = self._execution_options.union(opt) self.dispatch.set_engine_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) self.dialect.set_engine_execution_options(self, opt) - def execution_options(self, **opt): + def execution_options(self, **opt: Any) -> OptionEngine: """Return a new :class:`_engine.Engine` that will provide :class:`_engine.Connection` objects with the given execution options. @@ -2478,7 +2646,7 @@ class Engine( """ # noqa E501 return self._option_cls(self, opt) - def get_execution_options(self): + def get_execution_options(self) -> _ExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded: 1.3 @@ -2490,14 +2658,14 @@ class Engine( return self._execution_options @property - def name(self): + def name(self) -> str: """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` in use by this :class:`Engine`.""" return self.dialect.name @property - def driver(self): + def driver(self) -> str: """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` in use by this :class:`Engine`.""" @@ -2505,10 +2673,10 @@ class Engine( echo = log.echo_property() - def __repr__(self): + def __repr__(self) -> str: return "Engine(%r)" % (self.url,) - def dispose(self): + def dispose(self) -> None: """Dispose of the connection pool used by this :class:`_engine.Engine`. @@ -2538,7 +2706,9 @@ class Engine( self.dispatch.engine_disposed(self) @contextlib.contextmanager - def _optional_conn_ctx_manager(self, connection=None): + def _optional_conn_ctx_manager( + self, connection: Optional[Connection] = None + ) -> Iterator[Connection]: if connection is None: with self.connect() as conn: yield conn @@ -2546,7 +2716,7 @@ class Engine( yield connection @contextlib.contextmanager - def begin(self): + def begin(self) -> Iterator[Connection]: """Return a context manager delivering a :class:`_engine.Connection` with a :class:`.Transaction` established. @@ -2576,11 +2746,16 @@ class Engine( with conn.begin(): yield conn - def _run_ddl_visitor(self, visitorcallable, element, **kwargs): + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: DDLElement, + **kwargs: Any, + ) -> None: with self.begin() as conn: conn._run_ddl_visitor(visitorcallable, element, **kwargs) - def connect(self): + def connect(self) -> Connection: """Return a new :class:`_engine.Connection` object. The :class:`_engine.Connection` acts as a Python context manager, so @@ -2605,7 +2780,7 @@ class Engine( return self._connection_cls(self) - def raw_connection(self): + def raw_connection(self) -> PoolProxiedConnection: """Return a "raw" DBAPI connection from the connection pool. The returned object is a proxied version of the DBAPI @@ -2630,10 +2805,20 @@ class Engine( return self.pool.connect() -class OptionEngineMixin: +class OptionEngineMixin(log.Identified): _sa_propagate_class_events = False - def __init__(self, proxied, execution_options): + dispatch: dispatcher[ConnectionEventsTarget] + _compiled_cache: Optional[_CompiledCacheType] + dialect: Dialect + pool: Pool + url: URL + hide_parameters: bool + echo: log.echo_property + + def __init__( + self, proxied: Engine, execution_options: _ExecuteOptionsParameter + ): self._proxied = proxied self.url = proxied.url self.dialect = proxied.dialect @@ -2660,27 +2845,34 @@ class OptionEngineMixin: self._execution_options = proxied._execution_options self.update_execution_options(**execution_options) - def _get_pool(self): - return self._proxied.pool + def update_execution_options(self, **opt: Any) -> None: + raise NotImplementedError() - def _set_pool(self, pool): - self._proxied.pool = pool + if not typing.TYPE_CHECKING: + # https://github.com/python/typing/discussions/1095 - pool = property(_get_pool, _set_pool) + @property + def pool(self) -> Pool: + return self._proxied.pool - def _get_has_events(self): - return self._proxied._has_events or self.__dict__.get( - "_has_events", False - ) + @pool.setter + def pool(self, pool: Pool) -> None: + self._proxied.pool = pool - def _set_has_events(self, value): - self.__dict__["_has_events"] = value + @property + def _has_events(self) -> bool: + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) - _has_events = property(_get_has_events, _set_has_events) + @_has_events.setter + def _has_events(self, value: bool) -> None: + self.__dict__["_has_events"] = value class OptionEngine(OptionEngineMixin, Engine): - pass + def update_execution_options(self, **opt: Any) -> None: + Engine.update_execution_options(self, **opt) Engine._option_cls = OptionEngine diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py index c3674c931e..c0feb000be 100644 --- a/lib/sqlalchemy/engine/characteristics.py +++ b/lib/sqlalchemy/engine/characteristics.py @@ -1,6 +1,13 @@ from __future__ import annotations import abc +import typing +from typing import Any +from typing import ClassVar + +if typing.TYPE_CHECKING: + from .interfaces import DBAPIConnection + from .interfaces import Dialect class ConnectionCharacteristic(abc.ABC): @@ -25,18 +32,24 @@ class ConnectionCharacteristic(abc.ABC): __slots__ = () - transactional = False + transactional: ClassVar[bool] = False @abc.abstractmethod - def reset_characteristic(self, dialect, dbapi_conn): + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: """Reset the characteristic on the connection to its default value.""" @abc.abstractmethod - def set_characteristic(self, dialect, dbapi_conn, value): + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: """set characteristic on the connection to a given value.""" @abc.abstractmethod - def get_characteristic(self, dialect, dbapi_conn): + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: """Given a DBAPI connection, get the current value of the characteristic. @@ -44,13 +57,19 @@ class ConnectionCharacteristic(abc.ABC): class IsolationLevelCharacteristic(ConnectionCharacteristic): - transactional = True + transactional: ClassVar[bool] = True - def reset_characteristic(self, dialect, dbapi_conn): + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: dialect.reset_isolation_level(dbapi_conn) - def set_characteristic(self, dialect, dbapi_conn, value): + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: dialect._assert_and_set_isolation_level(dbapi_conn, value) - def get_characteristic(self, dialect, dbapi_conn): + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: return dialect.get_isolation_level(dbapi_conn) diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index ac3d6a2d89..cb5219396b 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -7,7 +7,12 @@ from __future__ import annotations +import inspect +import typing from typing import Any +from typing import cast +from typing import Dict +from typing import Optional from typing import Union from . import base @@ -21,6 +26,9 @@ from ..pool import _AdhocProxiedConnection from ..pool import ConnectionPoolEntry from ..sql import compiler +if typing.TYPE_CHECKING: + from .base import Engine + @util.deprecated_params( strategy=( @@ -46,7 +54,7 @@ from ..sql import compiler "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": +def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> Engine: """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL ` as the @@ -452,7 +460,8 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": if "strategy" in kwargs: strat = kwargs.pop("strategy") if strat == "mock": - return create_mock_engine(url, **kwargs) + # this case is deprecated + return create_mock_engine(url, **kwargs) # type: ignore else: raise exc.ArgumentError("unknown strategy: %r" % strat) @@ -472,14 +481,14 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": if kwargs.pop("_coerce_config", False): - def pop_kwarg(key, default=None): + def pop_kwarg(key: str, default: Optional[Any] = None) -> Any: value = kwargs.pop(key, default) if key in dialect_cls.engine_config_types: value = dialect_cls.engine_config_types[key](value) return value else: - pop_kwarg = kwargs.pop + pop_kwarg = kwargs.pop # type: ignore dialect_args = {} # consume dialect arguments from kwargs @@ -490,10 +499,29 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": dbapi = kwargs.pop("module", None) if dbapi is None: dbapi_args = {} - for k in util.get_func_kwargs(dialect_cls.dbapi): + + if "import_dbapi" in dialect_cls.__dict__: + dbapi_meth = dialect_cls.import_dbapi + + elif hasattr(dialect_cls, "dbapi") and inspect.ismethod( + dialect_cls.dbapi + ): + util.warn_deprecated( + "The dbapi() classmethod on dialect classes has been " + "renamed to import_dbapi(). Implement an import_dbapi() " + f"classmethod directly on class {dialect_cls} to remove this " + "warning; the old .dbapi() classmethod may be maintained for " + "backwards compatibility.", + "2.0", + ) + dbapi_meth = dialect_cls.dbapi + else: + dbapi_meth = dialect_cls.import_dbapi + + for k in util.get_func_kwargs(dbapi_meth): if k in kwargs: dbapi_args[k] = pop_kwarg(k) - dbapi = dialect_cls.dbapi(**dbapi_args) + dbapi = dbapi_meth(**dbapi_args) dialect_args["dbapi"] = dbapi @@ -509,18 +537,23 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": dialect = dialect_cls(**dialect_args) # assemble connection arguments - (cargs, cparams) = dialect.create_connect_args(u) + (cargs_tup, cparams) = dialect.create_connect_args(u) cparams.update(pop_kwarg("connect_args", {})) - cargs = list(cargs) # allow mutability + cargs = list(cargs_tup) # allow mutability # look for existing pool or create pool = pop_kwarg("pool", None) if pool is None: - def connect(connection_record=None): + def connect( + connection_record: Optional[ConnectionPoolEntry] = None, + ) -> DBAPIConnection: if dialect._has_events: for fn in dialect.dispatch.do_connect: - connection = fn(dialect, connection_record, cargs, cparams) + connection = cast( + DBAPIConnection, + fn(dialect, connection_record, cargs, cparams), + ) if connection is not None: return connection return dialect.connect(*cargs, **cparams) @@ -596,7 +629,11 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": do_on_connect = dialect.on_connect_url(u) if do_on_connect: - def on_connect(dbapi_connection, connection_record): + def on_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + assert do_on_connect is not None do_on_connect(dbapi_connection) event.listen(pool, "connect", on_connect) @@ -608,7 +645,7 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": def first_connect( dbapi_connection: DBAPIConnection, connection_record: ConnectionPoolEntry, - ): + ) -> None: c = base.Connection( engine, connection=_AdhocProxiedConnection( @@ -654,7 +691,9 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine": return engine -def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): +def engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> Engine: """Create a new Engine instance using a configuration dictionary. The dictionary is typically produced from a config file. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 2b077056fa..78805bac1b 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -13,6 +13,17 @@ from __future__ import annotations import collections import functools +import typing +from typing import Any +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type from .result import Result from .result import ResultMetaData @@ -30,19 +41,43 @@ from ..sql.compiler import RM_OBJECTS from ..sql.compiler import RM_RENDERED_NAME from ..sql.compiler import RM_TYPE from ..util import compat +from ..util.typing import Literal _UNPICKLED = util.symbol("unpickled") +if typing.TYPE_CHECKING: + from .interfaces import _DBAPICursorDescription + from .interfaces import ExecutionContext + from .result import _KeyIndexType + from .result import _KeyMapRecType + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _ProcessorType + # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. -MD_INDEX = 0 # integer index in cursor.description -MD_RESULT_MAP_INDEX = 1 # integer index in compiled._result_columns -MD_OBJECTS = 2 # other string keys and ColumnElement obj that can match -MD_LOOKUP_KEY = 3 # string key we usually expect for key-based lookup -MD_RENDERED_NAME = 4 # name that is usually in cursor.description -MD_PROCESSOR = 5 # callable to process a result value into a row -MD_UNTRANSLATED = 6 # raw name from cursor.description +MD_INDEX: Literal[0] = 0 # integer index in cursor.description +MD_RESULT_MAP_INDEX: Literal[ + 1 +] = 1 # integer index in compiled._result_columns +MD_OBJECTS: Literal[ + 2 +] = 2 # other string keys and ColumnElement obj that can match +MD_LOOKUP_KEY: Literal[ + 3 +] = 3 # string key we usually expect for key-based lookup +MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description +MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row +MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description + + +_CursorKeyMapRecType = Tuple[ + int, int, List[Any], str, str, Optional["_ProcessorType"], str +] + +_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] class CursorResultMetaData(ResultMetaData): @@ -61,22 +96,30 @@ class CursorResultMetaData(ResultMetaData): # if a need arises. ) - returns_rows = True + _keymap: _CursorKeyMapType + _processors: _ProcessorsType + _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]] + _unpickled: bool + _safe_for_cache: bool + + returns_rows: ClassVar[bool] = True - def _has_key(self, key): + def _has_key(self, key: Any) -> bool: return key in self._keymap - def _for_freeze(self): + def _for_freeze(self) -> ResultMetaData: return SimpleResultMetaData( self._keys, extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], ) - def _reduce(self, keys): - recs = list(self._metadata_for_keys(keys)) + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = cast( + "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys)) + ) indexes = [rec[MD_INDEX] for rec in recs] - new_keys = [rec[MD_LOOKUP_KEY] for rec in recs] + new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] @@ -104,7 +147,7 @@ class CursorResultMetaData(ResultMetaData): return new_metadata - def _adapt_to_context(self, context): + def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: """When using a cached Compiled construct that has a _result_map, for a new statement that used the cached Compiled, we need to ensure the keymap has the Column objects from our new statement as keys. @@ -112,8 +155,7 @@ class CursorResultMetaData(ResultMetaData): as matched to those of the cached statement. """ - - if not context.compiled._result_columns: + if not context.compiled or not context.compiled._result_columns: return self compiled_statement = context.compiled.statement @@ -122,6 +164,8 @@ class CursorResultMetaData(ResultMetaData): if compiled_statement is invoked_statement: return self + assert invoked_statement is not None + # this is the most common path for Core statements when # caching is used. In ORM use, this codepath is not really used # as the _result_disable_adapt_to_context execution option is @@ -162,7 +206,9 @@ class CursorResultMetaData(ResultMetaData): md._safe_for_cache = self._safe_for_cache return md - def __init__(self, parent, cursor_description): + def __init__( + self, parent: CursorResult, cursor_description: _DBAPICursorDescription + ): context = parent.context self._tuplefilter = None self._translated_indexes = None @@ -229,7 +275,7 @@ class CursorResultMetaData(ResultMetaData): # new in 1.4: get the complete set of all possible keys, # strings, objects, whatever, that are dupes across two # different records, first. - index_by_key = {} + index_by_key: Dict[Any, Any] = {} dupes = set() for metadata_entry in raw: for key in (metadata_entry[MD_RENDERED_NAME],) + ( @@ -626,7 +672,7 @@ class CursorResultMetaData(ResultMetaData): "result set column descriptions" % rec[MD_LOOKUP_KEY] ) - def _index_for_key(self, key, raiseerr=True): + def _index_for_key(self, key: Any, raiseerr: bool = True) -> Optional[int]: # TODO: can consider pre-loading ints and negative ints # into _keymap - also no coverage here if isinstance(key, int): @@ -653,7 +699,9 @@ class CursorResultMetaData(ResultMetaData): # ensure it raises CursorResultMetaData._key_fallback(self, ke.args[0], ke) - def _metadata_for_keys(self, keys): + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_CursorKeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -707,7 +755,7 @@ class ResultFetchStrategy: __slots__ = () - alternate_cursor_description = None + alternate_cursor_description: Optional[_DBAPICursorDescription] = None def soft_close(self, result, dbapi_cursor): raise NotImplementedError() @@ -1099,10 +1147,9 @@ _NO_RESULT_METADATA = _NoResultMetaData() class BaseCursorResult: """Base class for database result objects.""" - out_parameters = None - _metadata = None - _soft_closed = False - closed = False + _metadata: ResultMetaData + _soft_closed: bool = False + closed: bool = False def __init__(self, context, cursor_strategy, cursor_description): self.context = context @@ -1134,7 +1181,7 @@ class BaseCursorResult: keymap = metadata._keymap processors = metadata._processors - process_row = self._process_row + process_row = Row key_style = process_row._default_key_style _make_row = functools.partial( process_row, metadata, processors, keymap, key_style @@ -1644,7 +1691,7 @@ class CursorResult(BaseCursorResult, Result): """ - _cursor_metadata = CursorResultMetaData + _cursor_metadata: Type[ResultMetaData] = CursorResultMetaData _cursor_strategy_cls = CursorFetchStrategy _no_result_metadata = _NO_RESULT_METADATA _is_cursor = True @@ -1719,7 +1766,9 @@ class BufferedRowResultProxy(ResultProxy): """ - _cursor_strategy_cls = BufferedRowCursorFetchStrategy + _cursor_strategy_cls: Type[ + CursorFetchStrategy + ] = BufferedRowCursorFetchStrategy class FullyBufferedResultProxy(ResultProxy): @@ -1744,5 +1793,3 @@ class BufferedColumnResultProxy(ResultProxy): and this class does not change behavior in any way. """ - - _process_row = BufferedColumnRow diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a4dbf2361e..0e0c76389a 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -19,12 +19,30 @@ import functools import random import re from time import perf_counter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type import weakref from . import characteristics from . import cursor as _cursor from . import interfaces from .base import Connection +from .interfaces import CacheStats +from .interfaces import DBAPICursor +from .interfaces import Dialect +from .interfaces import ExecutionContext from .. import event from .. import exc from .. import pool @@ -32,25 +50,49 @@ from .. import types as sqltypes from .. import util from ..sql import compiler from ..sql import expression +from ..sql.compiler import DDLCompiler +from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name +if typing.TYPE_CHECKING: + from .interfaces import _AnyMultiExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .result import _ProcessorType + from .row import Row + from .url import URL + from ..event import _ListenerFnType + from ..pool import Pool + from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql.compiler import Compiled + from ..sql.compiler import ResultColumnsEntry + from ..sql.schema import Column + from ..sql.type_api import TypeEngine + # When we're handed literal SQL, ensure it's a SELECT query SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) -CACHE_HIT = util.symbol("CACHE_HIT") -CACHE_MISS = util.symbol("CACHE_MISS") -CACHING_DISABLED = util.symbol("CACHING_DISABLED") -NO_CACHE_KEY = util.symbol("NO_CACHE_KEY") -NO_DIALECT_SUPPORT = util.symbol("NO_DIALECT_SUPPORT") +( + CACHE_HIT, + CACHE_MISS, + CACHING_DISABLED, + NO_CACHE_KEY, + NO_DIALECT_SUPPORT, +) = list(CacheStats) -class DefaultDialect(interfaces.Dialect): +class DefaultDialect(Dialect): """Default implementation of Dialect""" statement_compiler = compiler.SQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.GenericTypeCompiler + type_compiler = compiler.GenericTypeCompiler # type: ignore preparer = compiler.IdentifierPreparer supports_alter = True supports_comments = False @@ -61,8 +103,8 @@ class DefaultDialect(interfaces.Dialect): bind_typing = interfaces.BindTyping.NONE - include_set_input_sizes = None - exclude_set_input_sizes = None + include_set_input_sizes: Optional[Set[Any]] = None + exclude_set_input_sizes: Optional[Set[Any]] = None # the first value we'd get for an autoincrement # column. @@ -70,7 +112,7 @@ class DefaultDialect(interfaces.Dialect): # most DBAPIs happy with this for execute(). # not cx_oracle. - execute_sequence_format = tuple + execute_sequence_format = tuple # type: ignore supports_schemas = True supports_views = True @@ -97,16 +139,16 @@ class DefaultDialect(interfaces.Dialect): {"isolation_level": characteristics.IsolationLevelCharacteristic()} ) - engine_config_types = util.immutabledict( - [ - ("pool_timeout", util.asint), - ("echo", util.bool_or_str("debug")), - ("echo_pool", util.bool_or_str("debug")), - ("pool_recycle", util.asint), - ("pool_size", util.asint), - ("max_overflow", util.asint), - ("future", util.asbool), - ] + engine_config_types: Mapping[str, Any] = util.immutabledict( + { + "pool_timeout": util.asint, + "echo": util.bool_or_str("debug"), + "echo_pool": util.bool_or_str("debug"), + "pool_recycle": util.asint, + "pool_size": util.asint, + "max_overflow": util.asint, + "future": util.asbool, + } ) # if the NUMERIC type @@ -119,19 +161,21 @@ class DefaultDialect(interfaces.Dialect): # length at which to truncate # any identifier. max_identifier_length = 9999 - _user_defined_max_identifier_length = None + _user_defined_max_identifier_length: Optional[int] = None - isolation_level = None + isolation_level: Optional[str] = None # sub-categories of max_identifier_length. # currently these accommodate for MySQL which allows alias names # of 255 but DDL names only of 64. - max_index_name_length = None - max_constraint_name_length = None + max_index_name_length: Optional[int] = None + max_constraint_name_length: Optional[int] = None supports_sane_rowcount = True supports_sane_multi_rowcount = True - colspecs = {} + colspecs: MutableMapping[ + Type["TypeEngine[Any]"], Type["TypeEngine[Any]"] + ] = {} default_paramstyle = "named" supports_default_values = False @@ -160,43 +204,6 @@ class DefaultDialect(interfaces.Dialect): default_schema_name = None - construct_arguments = None - """Optional set of argument specifiers for various SQLAlchemy - constructs, typically schema items. - - To implement, establish as a series of tuples, as in:: - - construct_arguments = [ - (schema.Index, { - "using": False, - "where": None, - "ops": None - }) - ] - - If the above construct is established on the PostgreSQL dialect, - the :class:`.Index` construct will now accept the keyword arguments - ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. - Any other argument specified to the constructor of :class:`.Index` - which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. - - A dialect which does not include a ``construct_arguments`` member will - not participate in the argument validation system. For such a dialect, - any argument name is accepted by all participating constructs, within - the namespace of arguments prefixed with that dialect name. The rationale - here is so that third-party dialects that haven't yet implemented this - feature continue to function in the old way. - - .. versionadded:: 0.9.2 - - .. seealso:: - - :class:`.DialectKWArgs` - implementing base class which consumes - :attr:`.DefaultDialect.construct_arguments` - - - """ - # indicates symbol names are # UPPERCASEd if they are case insensitive # within the database. @@ -204,17 +211,6 @@ class DefaultDialect(interfaces.Dialect): # and denormalize_name() must be provided. requires_name_normalize = False - reflection_options = () - - dbapi_exception_translation_map = util.immutabledict() - """mapping used in the extremely unusual case that a DBAPI's - published exceptions don't actually have the __name__ that they - are linked towards. - - .. versionadded:: 1.0.5 - - """ - is_async = False CACHE_HIT = CACHE_HIT @@ -363,10 +359,10 @@ class DefaultDialect(interfaces.Dialect): return self.supports_sane_rowcount @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: URL) -> Type[Pool]: return getattr(cls, "poolclass", pool.QueuePool) - def get_dialect_pool_class(self, url): + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: return self.get_pool_class(url) @classmethod @@ -377,7 +373,7 @@ class DefaultDialect(interfaces.Dialect): except ImportError: pass - def _builtin_onconnect(self): + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: if self._on_connect_isolation_level is not None: def builtin_connect(dbapi_conn, conn_rec): @@ -734,7 +730,7 @@ class StrCompileDialect(DefaultDialect): statement_compiler = compiler.StrSQLCompiler ddl_compiler = compiler.DDLCompiler - type_compiler = compiler.StrSQLTypeCompiler + type_compiler = compiler.StrSQLTypeCompiler # type: ignore preparer = compiler.IdentifierPreparer supports_statement_cache = True @@ -758,24 +754,26 @@ class StrCompileDialect(DefaultDialect): } -class DefaultExecutionContext(interfaces.ExecutionContext): +class DefaultExecutionContext(ExecutionContext): isinsert = False isupdate = False isdelete = False is_crud = False is_text = False isddl = False + executemany = False - compiled = None - statement = None - result_column_struct = None - returned_default_rows = None - execution_options = util.immutabledict() + compiled: Optional[Compiled] = None + result_column_struct: Optional[ + Tuple[List[ResultColumnsEntry], bool, bool, bool] + ] = None + returned_default_rows: Optional[List[Row]] = None + + execution_options: _ExecuteOptions = util.EMPTY_DICT cursor_fetch_strategy = _cursor._DEFAULT_FETCH - cache_stats = None - invoked_statement = None + invoked_statement: Optional[Executable] = None _is_implicit_returning = False _is_explicit_returning = False @@ -786,21 +784,37 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # a hook for SQLite's translation of # result column names # NOTE: pyhive is using this hook, can't remove it :( - _translate_colname = None + _translate_colname: Optional[Callable[[str], str]] = None + + _expanded_parameters: Mapping[str, List[str]] = util.immutabledict() + """used by set_input_sizes(). + + This collection comes from ``ExpandedState.parameter_expansion``. - _expanded_parameters = util.immutabledict() + """ cache_hit = NO_CACHE_KEY + root_connection: Connection + _dbapi_connection: PoolProxiedConnection + dialect: Dialect + unicode_statement: str + cursor: DBAPICursor + compiled_parameters: _CoreMultiExecuteParams + parameters: _DBAPIMultiExecuteParams + extracted_parameters: _CoreSingleExecuteParams + + _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT) + @classmethod def _init_ddl( cls, - dialect, - connection, - dbapi_connection, - execution_options, - compiled_ddl, - ): + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: """Initialize execution context for a DDLElement construct.""" self = cls.__new__(cls) @@ -832,23 +846,23 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if dialect.positional: self.parameters = [dialect.execute_sequence_format()] else: - self.parameters = [{}] + self.parameters = [self._empty_dict_params] return self @classmethod def _init_compiled( cls, - dialect, - connection, - dbapi_connection, - execution_options, - compiled, - parameters, - invoked_statement, - extracted_parameters, - cache_hit=CACHING_DISABLED, - ): + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: _CoreSingleExecuteParams, + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: """Initialize execution context for a Compiled construct.""" self = cls.__new__(cls) @@ -868,6 +882,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): compiled._textual_ordered_columns, compiled._loose_column_name_matching, ) + self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate self.isdelete = compiled.isdelete @@ -910,6 +925,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): processors = compiled._bind_processors + flattened_processors: Mapping[ + str, _ProcessorType + ] = processors # type: ignore[assignment] + if compiled.literal_execute_params or compiled.post_compile_params: if self.executemany: raise exc.InvalidRequestError( @@ -924,14 +943,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # re-assign self.unicode_statement self.unicode_statement = expanded_state.statement - # used by set_input_sizes() which is needed for Oracle self._expanded_parameters = expanded_state.parameter_expansion - processors = dict(processors) - processors.update(expanded_state.processors) + flattened_processors = dict(processors) # type: ignore + flattened_processors.update(expanded_state.processors) positiontup = expanded_state.positiontup elif compiled.positional: positiontup = self.compiled.positiontup + else: + positiontup = None if compiled.schema_translate_map: schema_translate_map = self.execution_options.get( @@ -949,42 +969,49 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # Convert the dictionary of bind parameter values # into a dict or list to be sent to the DBAPI's # execute() or executemany() method. - parameters = [] + if compiled.positional: + core_positional_parameters: MutableSequence[Sequence[Any]] = [] + assert positiontup is not None for compiled_params in self.compiled_parameters: - param = [ - processors[key](compiled_params[key]) - if key in processors + l_param: List[Any] = [ + flattened_processors[key](compiled_params[key]) + if key in flattened_processors else compiled_params[key] for key in positiontup ] - parameters.append(dialect.execute_sequence_format(param)) + core_positional_parameters.append( + dialect.execute_sequence_format(l_param) + ) + + self.parameters = core_positional_parameters else: + core_dict_parameters: MutableSequence[Dict[str, Any]] = [] for compiled_params in self.compiled_parameters: - param = { - key: processors[key](compiled_params[key]) - if key in processors + d_param: Dict[str, Any] = { + key: flattened_processors[key](compiled_params[key]) + if key in flattened_processors else compiled_params[key] for key in compiled_params } - parameters.append(param) + core_dict_parameters.append(d_param) - self.parameters = dialect.execute_sequence_format(parameters) + self.parameters = core_dict_parameters return self @classmethod def _init_statement( cls, - dialect, - connection, - dbapi_connection, - execution_options, - statement, - parameters, - ): + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: """Initialize execution context for a string SQL statement.""" self = cls.__new__(cls) @@ -999,7 +1026,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if self.dialect.positional: self.parameters = [dialect.execute_sequence_format()] else: - self.parameters = [{}] + self.parameters = [self._empty_dict_params] elif isinstance(parameters[0], dialect.execute_sequence_format): self.parameters = parameters elif isinstance(parameters[0], dict): @@ -1018,8 +1045,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @classmethod def _init_default( - cls, dialect, connection, dbapi_connection, execution_options - ): + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: """Initialize execution context for a ColumnDefault construct.""" self = cls.__new__(cls) @@ -1032,7 +1063,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() return self - def _get_cache_stats(self): + def _get_cache_stats(self) -> str: if self.compiled is None: return "raw sql" @@ -1040,19 +1071,22 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ch = self.cache_hit + gen_time = self.compiled._gen_time + assert gen_time is not None + if ch is NO_CACHE_KEY: - return "no key %.5fs" % (now - self.compiled._gen_time,) + return "no key %.5fs" % (now - gen_time,) elif ch is CACHE_HIT: - return "cached since %.4gs ago" % (now - self.compiled._gen_time,) + return "cached since %.4gs ago" % (now - gen_time,) elif ch is CACHE_MISS: - return "generated in %.5fs" % (now - self.compiled._gen_time,) + return "generated in %.5fs" % (now - gen_time,) elif ch is CACHING_DISABLED: - return "caching disabled %.5fs" % (now - self.compiled._gen_time,) + return "caching disabled %.5fs" % (now - gen_time,) elif ch is NO_DIALECT_SUPPORT: return "dialect %s+%s does not support caching %.5fs" % ( self.dialect.name, self.dialect.driver, - now - self.compiled._gen_time, + now - gen_time, ) else: return "unknown" @@ -1073,11 +1107,13 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.root_connection.engine @util.memoized_property - def postfetch_cols(self): + def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501 + assert isinstance(self.compiled, SQLCompiler) return self.compiled.postfetch @util.memoized_property - def prefetch_cols(self): + def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: # type: ignore[override] # mypy#4125 # noqa E501 + assert isinstance(self.compiled, SQLCompiler) if self.isinsert: return self.compiled.insert_prefetch elif self.isupdate: @@ -1086,8 +1122,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return () @util.memoized_property - def returning_cols(self): - self.compiled.returning + def returning_cols(self) -> Optional[Sequence[Column[Any]]]: + assert isinstance(self.compiled, SQLCompiler) + return self.compiled.returning @util.memoized_property def no_parameters(self): @@ -1564,7 +1601,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): str(compiled), type_, parameters=parameters ) - current_parameters = None + current_parameters: Optional[_CoreSingleExecuteParams] = None """A dictionary of parameters applied to the current row. This attribute is only available in the context of a user-defined default diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index ab462bbe1f..0cbf56a6d5 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -8,14 +8,41 @@ from __future__ import annotations +import typing +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + from .base import Engine from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPIConnection +from .interfaces import DBAPICursor from .interfaces import Dialect from .. import event from .. import exc - - -class ConnectionEvents(event.Events): +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .base import Connection + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import ExceptionContext + from .interfaces import ExecutionContext + from .result import Result + from ..pool import ConnectionPoolEntry + from ..sql import Executable + from ..sql.elements import BindParameter + + +class ConnectionEvents(event.Events[ConnectionEventsTarget]): """Available events for :class:`_engine.Connection` and :class:`_engine.Engine`. @@ -96,7 +123,12 @@ class ConnectionEvents(event.Events): _dispatch_target = ConnectionEventsTarget @classmethod - def _listen(cls, event_key, retval=False): + def _listen( # type: ignore[override] + cls, + event_key: event._EventKey[ConnectionEventsTarget], + retval: bool = False, + **kw: Any, + ) -> None: target, identifier, fn = ( event_key.dispatch_target, event_key.identifier, @@ -109,7 +141,7 @@ class ConnectionEvents(event.Events): if identifier == "before_execute": orig_fn = fn - def wrap_before_execute( + def wrap_before_execute( # type: ignore conn, clauseelement, multiparams, params, execution_options ): orig_fn( @@ -125,7 +157,7 @@ class ConnectionEvents(event.Events): elif identifier == "before_cursor_execute": orig_fn = fn - def wrap_before_cursor_execute( + def wrap_before_cursor_execute( # type: ignore conn, cursor, statement, parameters, context, executemany ): orig_fn( @@ -163,8 +195,15 @@ class ConnectionEvents(event.Events): ), ) def before_execute( - self, conn, clauseelement, multiparams, params, execution_options - ): + self, + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + ) -> Optional[ + Tuple[Executable, _CoreMultiExecuteParams, _CoreSingleExecuteParams] + ]: """Intercept high level execute() events, receiving uncompiled SQL constructs and other objects prior to rendering into SQL. @@ -214,13 +253,13 @@ class ConnectionEvents(event.Events): ) def after_execute( self, - conn, - clauseelement, - multiparams, - params, - execution_options, - result, - ): + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + result: Result, + ) -> None: """Intercept high level execute() events after execute. @@ -244,8 +283,14 @@ class ConnectionEvents(event.Events): """ def before_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> Optional[Tuple[str, _DBAPIAnyExecuteParams]]: """Intercept low-level cursor execute() events before execution, receiving the string SQL statement and DBAPI-specific parameter list to be invoked against a cursor. @@ -286,8 +331,14 @@ class ConnectionEvents(event.Events): """ def after_cursor_execute( - self, conn, cursor, statement, parameters, context, executemany - ): + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> None: """Intercept low-level cursor execute() events after execution. :param conn: :class:`_engine.Connection` object @@ -305,7 +356,9 @@ class ConnectionEvents(event.Events): """ - def handle_error(self, exception_context): + def handle_error( + self, exception_context: ExceptionContext + ) -> Optional[BaseException]: r"""Intercept all exceptions processed by the :class:`_engine.Connection`. @@ -439,7 +492,7 @@ class ConnectionEvents(event.Events): @event._legacy_signature( "2.0", ["conn", "branch"], converter=lambda conn: (conn, False) ) - def engine_connect(self, conn): + def engine_connect(self, conn: Connection) -> None: """Intercept the creation of a new :class:`_engine.Connection`. This event is called typically as the direct result of calling @@ -475,7 +528,9 @@ class ConnectionEvents(event.Events): """ - def set_connection_execution_options(self, conn, opts): + def set_connection_execution_options( + self, conn: Connection, opts: Dict[str, Any] + ) -> None: """Intercept when the :meth:`_engine.Connection.execution_options` method is called. @@ -494,8 +549,12 @@ class ConnectionEvents(event.Events): :param opts: dictionary of options that were passed to the :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. + + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. - .. versionadded:: 0.9.0 .. seealso:: @@ -507,7 +566,9 @@ class ConnectionEvents(event.Events): """ - def set_engine_execution_options(self, engine, opts): + def set_engine_execution_options( + self, engine: Engine, opts: Dict[str, Any] + ) -> None: """Intercept when the :meth:`_engine.Engine.execution_options` method is called. @@ -526,8 +587,11 @@ class ConnectionEvents(event.Events): :param opts: dictionary of options that were passed to the :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. - .. versionadded:: 0.9.0 + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. .. seealso:: @@ -539,7 +603,7 @@ class ConnectionEvents(event.Events): """ - def engine_disposed(self, engine): + def engine_disposed(self, engine: Engine) -> None: """Intercept when the :meth:`_engine.Engine.dispose` method is called. The :meth:`_engine.Engine.dispose` method instructs the engine to @@ -559,14 +623,14 @@ class ConnectionEvents(event.Events): """ - def begin(self, conn): + def begin(self, conn: Connection) -> None: """Intercept begin() events. :param conn: :class:`_engine.Connection` object """ - def rollback(self, conn): + def rollback(self, conn: Connection) -> None: """Intercept rollback() events, as initiated by a :class:`.Transaction`. @@ -584,7 +648,7 @@ class ConnectionEvents(event.Events): """ - def commit(self, conn): + def commit(self, conn: Connection) -> None: """Intercept commit() events, as initiated by a :class:`.Transaction`. @@ -596,7 +660,7 @@ class ConnectionEvents(event.Events): :param conn: :class:`_engine.Connection` object """ - def savepoint(self, conn, name): + def savepoint(self, conn: Connection, name: str) -> None: """Intercept savepoint() events. :param conn: :class:`_engine.Connection` object @@ -604,7 +668,9 @@ class ConnectionEvents(event.Events): """ - def rollback_savepoint(self, conn, name, context): + def rollback_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: """Intercept rollback_savepoint() events. :param conn: :class:`_engine.Connection` object @@ -614,7 +680,9 @@ class ConnectionEvents(event.Events): """ # TODO: deprecate "context" - def release_savepoint(self, conn, name, context): + def release_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: """Intercept release_savepoint() events. :param conn: :class:`_engine.Connection` object @@ -624,7 +692,7 @@ class ConnectionEvents(event.Events): """ # TODO: deprecate "context" - def begin_twophase(self, conn, xid): + def begin_twophase(self, conn: Connection, xid: Any) -> None: """Intercept begin_twophase() events. :param conn: :class:`_engine.Connection` object @@ -632,14 +700,16 @@ class ConnectionEvents(event.Events): """ - def prepare_twophase(self, conn, xid): + def prepare_twophase(self, conn: Connection, xid: Any) -> None: """Intercept prepare_twophase() events. :param conn: :class:`_engine.Connection` object :param xid: two-phase XID identifier """ - def rollback_twophase(self, conn, xid, is_prepared): + def rollback_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: """Intercept rollback_twophase() events. :param conn: :class:`_engine.Connection` object @@ -649,7 +719,9 @@ class ConnectionEvents(event.Events): """ - def commit_twophase(self, conn, xid, is_prepared): + def commit_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: """Intercept commit_twophase() events. :param conn: :class:`_engine.Connection` object @@ -660,7 +732,7 @@ class ConnectionEvents(event.Events): """ -class DialectEvents(event.Events): +class DialectEvents(event.Events[Dialect]): """event interface for execution-replacement functions. These events allow direct instrumentation and replacement @@ -694,14 +766,20 @@ class DialectEvents(event.Events): _dispatch_target = Dialect @classmethod - def _listen(cls, event_key, retval=False): + def _listen( # type: ignore + cls, + event_key: event._EventKey[Dialect], + retval: bool = False, + ) -> None: target = event_key.dispatch_target target._has_events = True event_key.base_listen() @classmethod - def _accept_with(cls, target): + def _accept_with( + cls, target: Union[Engine, Type[Engine], Dialect, Type[Dialect]] + ) -> Union[Dialect, Type[Dialect]]: if isinstance(target, type): if issubclass(target, Engine): return Dialect @@ -712,7 +790,13 @@ class DialectEvents(event.Events): else: return target - def do_connect(self, dialect, conn_rec, cargs, cparams): + def do_connect( + self, + dialect: Dialect, + conn_rec: ConnectionPoolEntry, + cargs: Tuple[Any, ...], + cparams: Dict[str, Any], + ) -> Optional[DBAPIConnection]: """Receive connection arguments before a connection is made. This event is useful in that it allows the handler to manipulate the @@ -745,7 +829,13 @@ class DialectEvents(event.Events): """ - def do_executemany(self, cursor, statement, parameters, context): + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: """Receive a cursor to have executemany() called. Return the value True to halt further events from invoking, @@ -754,7 +844,9 @@ class DialectEvents(event.Events): """ - def do_execute_no_params(self, cursor, statement, context): + def do_execute_no_params( + self, cursor: DBAPICursor, statement: str, context: ExecutionContext + ) -> Optional[Literal[True]]: """Receive a cursor to have execute() with no parameters called. Return the value True to halt further events from invoking, @@ -763,7 +855,13 @@ class DialectEvents(event.Events): """ - def do_execute(self, cursor, statement, parameters, context): + def do_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: """Receive a cursor to have execute() called. Return the value True to halt further events from invoking, @@ -773,8 +871,13 @@ class DialectEvents(event.Events): """ def do_setinputsizes( - self, inputsizes, cursor, statement, parameters, context - ): + self, + inputsizes: Dict[BindParameter[Any], Any], + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: ExecutionContext, + ) -> None: """Receive the setinputsizes dictionary for possible modification. This event is emitted in the case where the dialect makes use of the diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 860c1faf95..545dd0ddcd 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -10,21 +10,31 @@ from __future__ import annotations from enum import Enum +from types import ModuleType from typing import Any +from typing import Awaitable from typing import Callable from typing import Dict from typing import List from typing import Mapping +from typing import MutableMapping from typing import Optional from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union +from .. import util +from ..event import EventTarget +from ..pool import Pool from ..pool import PoolProxiedConnection +from ..sql.compiler import Compiled as Compiled from ..sql.compiler import Compiled # noqa +from ..sql.compiler import TypeCompiler as TypeCompiler from ..sql.compiler import TypeCompiler # noqa +from ..util import immutabledict from ..util.concurrency import await_only from ..util.typing import _TypeToInstance from ..util.typing import NotRequired @@ -34,12 +44,33 @@ from ..util.typing import TypedDict if TYPE_CHECKING: from .base import Connection from .base import Engine + from .result import Result from .url import URL + from ..event import _ListenerFnType + from ..event import dispatcher + from ..exc import StatementError + from ..sql import Executable from ..sql.compiler import DDLCompiler from ..sql.compiler import IdentifierPreparer + from ..sql.compiler import Linting from ..sql.compiler import SQLCompiler + from ..sql.elements import ClauseElement + from ..sql.schema import Column + from ..sql.schema import ColumnDefault from ..sql.type_api import TypeEngine +ConnectArgsType = Tuple[Tuple[str], MutableMapping[str, Any]] + +_T = TypeVar("_T", bound="Any") + + +class CacheStats(Enum): + CACHE_HIT = 0 + CACHE_MISS = 1 + CACHING_DISABLED = 2 + NO_CACHE_KEY = 3 + NO_DIALECT_SUPPORT = 4 + class DBAPIConnection(Protocol): """protocol representing a :pep:`249` database connection. @@ -65,6 +96,8 @@ class DBAPIConnection(Protocol): def rollback(self) -> None: ... + autocommit: bool + class DBAPIType(Protocol): """protocol representing a :pep:`249` database type. @@ -128,14 +161,14 @@ class DBAPICursor(Protocol): def execute( self, operation: Any, - parameters: Optional[Union[Sequence[Any], Mapping[str, Any]]], + parameters: Optional[_DBAPISingleExecuteParams], ) -> Any: ... def executemany( self, operation: Any, - parameters: Sequence[Union[Sequence[Any], Mapping[str, Any]]], + parameters: Sequence[_DBAPIMultiExecuteParams], ) -> Any: ... @@ -161,6 +194,34 @@ class DBAPICursor(Protocol): ... +_CoreSingleExecuteParams = Mapping[str, Any] +_CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams] +_CoreAnyExecuteParams = Union[ + _CoreMultiExecuteParams, _CoreSingleExecuteParams +] + +_DBAPISingleExecuteParams = Union[Sequence[Any], _CoreSingleExecuteParams] + +_DBAPIMultiExecuteParams = Union[ + Sequence[Sequence[Any]], _CoreMultiExecuteParams +] +_DBAPIAnyExecuteParams = Union[ + _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams +] +_DBAPICursorDescription = Tuple[str, Any, Any, Any, Any, Any, Any] + +_AnySingleExecuteParams = _DBAPISingleExecuteParams +_AnyMultiExecuteParams = _DBAPIMultiExecuteParams +_AnyExecuteParams = _DBAPIAnyExecuteParams + + +_ExecuteOptions = immutabledict[str, Any] +_ExecuteOptionsParameter = Mapping[str, Any] +_SchemaTranslateMapType = Mapping[str, str] + +_ImmutableExecuteOptions = immutabledict[str, Any] + + class ReflectedIdentity(TypedDict): """represent the reflected IDENTITY structure of a column, corresponding to the :class:`_schema.Identity` construct. @@ -237,7 +298,7 @@ class ReflectedColumn(TypedDict): name: str """column name""" - type: "TypeEngine" + type: TypeEngine[Any] """column type represented as a :class:`.TypeEngine` instance.""" nullable: bool @@ -465,7 +526,10 @@ class BindTyping(Enum): """ -class Dialect: +VersionInfoType = Tuple[Union[int, str], ...] + + +class Dialect(EventTarget): """Define the behavior of a specific database and DB-API combination. Any aspect of metadata definition, SQL query generation, @@ -481,6 +545,8 @@ class Dialect: """ + dispatch: dispatcher[Dialect] + name: str """identifying name for the dialect from a DBAPI-neutral point of view (i.e. 'sqlite') @@ -489,6 +555,29 @@ class Dialect: driver: str """identifying name for the dialect's DBAPI""" + dbapi: ModuleType + """A reference to the DBAPI module object itself. + + SQLAlchemy dialects import DBAPI modules using the classmethod + :meth:`.Dialect.import_dbapi`. The rationale is so that any dialect + module can be imported and used to generate SQL statements without the + need for the actual DBAPI driver to be installed. Only when an + :class:`.Engine` is constructed using :func:`.create_engine` does the + DBAPI get imported; at that point, the creation process will assign + the DBAPI module to this attribute. + + Dialects should therefore implement :meth:`.Dialect.import_dbapi` + which will import the necessary module and return it, and then refer + to ``self.dbapi`` in dialect code in order to refer to the DBAPI module + contents. + + .. versionchanged:: The :attr:`.Dialect.dbapi` attribute is exclusively + used as the per-:class:`.Dialect`-instance reference to the DBAPI + module. The previous not-fully-documented ``.Dialect.dbapi()`` + classmethod is deprecated and replaced by :meth:`.Dialect.import_dbapi`. + + """ + positional: bool """True if the paramstyle for this Dialect is positional.""" @@ -497,21 +586,23 @@ class Dialect: paramstyles). """ - statement_compiler: Type["SQLCompiler"] + compiler_linting: Linting + + statement_compiler: Type[SQLCompiler] """a :class:`.Compiled` class used to compile SQL statements""" - ddl_compiler: Type["DDLCompiler"] + ddl_compiler: Type[DDLCompiler] """a :class:`.Compiled` class used to compile DDL statements""" - type_compiler: _TypeToInstance["TypeCompiler"] + type_compiler: _TypeToInstance[TypeCompiler] """a :class:`.Compiled` class used to compile SQL type objects""" - preparer: Type["IdentifierPreparer"] + preparer: Type[IdentifierPreparer] """a :class:`.IdentifierPreparer` class used to quote identifiers. """ - identifier_preparer: "IdentifierPreparer" + identifier_preparer: IdentifierPreparer """This element will refer to an instance of :class:`.IdentifierPreparer` once a :class:`.DefaultDialect` has been constructed. @@ -531,10 +622,15 @@ class Dialect: """ + default_isolation_level: str + """the isolation that is implicitly present on new connections""" + execution_ctx_cls: Type["ExecutionContext"] """a :class:`.ExecutionContext` class used to handle statement execution""" - execute_sequence_format: Union[Type[Tuple[Any, ...]], Type[List[Any]]] + execute_sequence_format: Union[ + Type[Tuple[Any, ...]], Type[Tuple[List[Any]]] + ] """either the 'tuple' or 'list' type, depending on what cursor.execute() accepts for the second argument (they vary).""" @@ -579,7 +675,7 @@ class Dialect: """ - colspecs: Dict[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]] + colspecs: MutableMapping[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]] """A dictionary of TypeEngine classes from sqlalchemy.types mapped to subclasses that are specific to the dialect class. This dictionary is class-level only and is not accessed from the @@ -610,7 +706,55 @@ class Dialect: constraint when that type is used. """ - dbapi_exception_translation_map: Dict[str, str] + construct_arguments: Optional[ + List[Tuple[Type[ClauseElement], Mapping[str, Any]]] + ] = None + """Optional set of argument specifiers for various SQLAlchemy + constructs, typically schema items. + + To implement, establish as a series of tuples, as in:: + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) + ] + + If the above construct is established on the PostgreSQL dialect, + the :class:`.Index` construct will now accept the keyword arguments + ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. + Any other argument specified to the constructor of :class:`.Index` + which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. + + A dialect which does not include a ``construct_arguments`` member will + not participate in the argument validation system. For such a dialect, + any argument name is accepted by all participating constructs, within + the namespace of arguments prefixed with that dialect name. The rationale + here is so that third-party dialects that haven't yet implemented this + feature continue to function in the old way. + + .. versionadded:: 0.9.2 + + .. seealso:: + + :class:`.DialectKWArgs` - implementing base class which consumes + :attr:`.DefaultDialect.construct_arguments` + + + """ + + reflection_options: Sequence[str] = () + """Sequence of string names indicating keyword arguments that can be + established on a :class:`.Table` object which will be passed as + "reflection options" when using :paramref:`.Table.autoload_with`. + + Current example is "oracle_resolve_synonyms" in the Oracle dialect. + + """ + + dbapi_exception_translation_map: Mapping[str, str] = util.EMPTY_DICT """A dictionary of names that will contain as values the names of pep-249 exceptions ("IntegrityError", "OperationalError", etc) keyed to alternate class names, to support the case where a @@ -660,9 +804,16 @@ class Dialect: is_async: bool """Whether or not this dialect is intended for asyncio use.""" - def create_connect_args( - self, url: "URL" - ) -> Tuple[Tuple[str], Mapping[str, Any]]: + engine_config_types: Mapping[str, Any] + """a mapping of string keys that can be in an engine config linked to + type conversion functions. + + """ + + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: + raise NotImplementedError() + + def create_connect_args(self, url: "URL") -> ConnectArgsType: """Build DB-API compatible connection arguments. Given a :class:`.URL` object, returns a tuple @@ -696,7 +847,25 @@ class Dialect: raise NotImplementedError() @classmethod - def type_descriptor(cls, typeobj: "TypeEngine") -> "TypeEngine": + def import_dbapi(cls) -> ModuleType: + """Import the DBAPI module that is used by this dialect. + + The Python module object returned here will be assigned as an + instance variable to a constructed dialect under the name + ``.dbapi``. + + .. versionchanged:: 2.0 The :meth:`.Dialect.import_dbapi` class + method is renamed from the previous method ``.Dialect.dbapi()``, + which would be replaced at dialect instantiation time by the + DBAPI module itself, thus using the same name in two different ways. + If a ``.Dialect.dbapi()`` classmethod is present on a third-party + dialect, it will be used and a deprecation warning will be emitted. + + """ + raise NotImplementedError() + + @classmethod + def type_descriptor(cls, typeobj: "TypeEngine[_T]") -> "TypeEngine[_T]": """Transform a generic type to a dialect-specific type. Dialect classes will usually use the @@ -735,7 +904,7 @@ class Dialect: connection: "Connection", table_name: str, schema: Optional[str] = None, - **kw, + **kw: Any, ) -> List[ReflectedColumn]: """Return information about columns in ``table_name``. @@ -908,11 +1077,12 @@ class Dialect: table_name: str, schema: Optional[str] = None, **kw: Any, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: r"""Return the "options" for the table identified by ``table_name`` as a dictionary. """ + return None def get_table_comment( self, @@ -1115,7 +1285,7 @@ class Dialect: def do_set_input_sizes( self, cursor: DBAPICursor, - list_of_tuples: List[Tuple[str, Any, "TypeEngine"]], + list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]], context: "ExecutionContext", ) -> Any: """invoke the cursor.setinputsizes() method with appropriate arguments @@ -1242,7 +1412,7 @@ class Dialect: raise NotImplementedError() - def do_recover_twophase(self, connection: "Connection") -> None: + def do_recover_twophase(self, connection: "Connection") -> List[Any]: """Recover list of uncommitted prepared two phase transaction identifiers on the given connection. @@ -1256,7 +1426,7 @@ class Dialect: self, cursor: DBAPICursor, statement: str, - parameters: List[Union[Dict[str, Any], Tuple[Any]]], + parameters: _DBAPIMultiExecuteParams, context: Optional["ExecutionContext"] = None, ) -> None: """Provide an implementation of ``cursor.executemany(statement, @@ -1268,9 +1438,9 @@ class Dialect: self, cursor: DBAPICursor, statement: str, - parameters: Union[Mapping[str, Any], Tuple[Any]], - context: Optional["ExecutionContext"] = None, - ): + parameters: Optional[_DBAPISingleExecuteParams], + context: Optional[ExecutionContext] = None, + ) -> None: """Provide an implementation of ``cursor.execute(statement, parameters)``.""" @@ -1281,7 +1451,7 @@ class Dialect: cursor: DBAPICursor, statement: str, context: Optional["ExecutionContext"] = None, - ): + ) -> None: """Provide an implementation of ``cursor.execute(statement)``. The parameter collection should not be sent. @@ -1294,14 +1464,14 @@ class Dialect: self, e: Exception, connection: Optional[PoolProxiedConnection], - cursor: DBAPICursor, + cursor: Optional[DBAPICursor], ) -> bool: """Return True if the given DB-API error indicates an invalid connection""" raise NotImplementedError() - def connect(self, *cargs: Any, **cparams: Any) -> Any: + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: r"""Establish a connection using this dialect's DBAPI. The default implementation of this method is:: @@ -1333,6 +1503,7 @@ class Dialect: :meth:`.Dialect.on_connect` """ + raise NotImplementedError() def on_connect_url(self, url: "URL") -> Optional[Callable[[Any], Any]]: """return a callable which sets up a newly created DBAPI connection. @@ -1542,7 +1713,7 @@ class Dialect: raise NotImplementedError() - def get_default_isolation_level(self, dbapi_conn: Any) -> str: + def get_default_isolation_level(self, dbapi_conn: DBAPIConnection) -> str: """Given a DBAPI connection, return its isolation level, or a default isolation level if one cannot be retrieved. @@ -1562,7 +1733,9 @@ class Dialect: """ raise NotImplementedError() - def get_isolation_level_values(self, dbapi_conn: Any) -> List[str]: + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> List[str]: """return a sequence of string isolation level names that are accepted by this dialect. @@ -1604,8 +1777,13 @@ class Dialect: """ raise NotImplementedError() + def _assert_and_set_isolation_level( + self, dbapi_conn: DBAPIConnection, level: str + ) -> None: + raise NotImplementedError() + @classmethod - def get_dialect_cls(cls, url: "URL") -> Type: + def get_dialect_cls(cls, url: URL) -> Type[Dialect]: """Given a URL, return the :class:`.Dialect` that will be used. This is a hook that allows an external plugin to provide functionality @@ -1621,7 +1799,7 @@ class Dialect: return cls @classmethod - def get_async_dialect_cls(cls, url: "URL") -> None: + def get_async_dialect_cls(cls, url: URL) -> Type[Dialect]: """Given a URL, return the :class:`.Dialect` that will be used by an async engine. @@ -1702,6 +1880,39 @@ class Dialect: """ raise NotImplementedError() + def set_engine_execution_options( + self, engine: Engine, opt: _ExecuteOptionsParameter + ) -> None: + """Establish execution options for a given engine. + + This is implemented by :class:`.DefaultDialect` to establish + event hooks for new :class:`.Connection` instances created + by the given :class:`.Engine` which will then invoke the + :meth:`.Dialect.set_connection_execution_options` method for that + connection. + + """ + raise NotImplementedError() + + def set_connection_execution_options( + self, connection: Connection, opt: _ExecuteOptionsParameter + ) -> None: + """Establish execution options for a given connection. + + This is implemented by :class:`.DefaultDialect` in order to implement + the :paramref:`_engine.Connection.execution_options.isolation_level` + execution option. Dialects can intercept various execution options + which may need to modify state on a particular DBAPI connection. + + .. versionadded:: 1.4 + + """ + raise NotImplementedError() + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + """return a Pool class to use for a given URL""" + raise NotImplementedError() + class CreateEnginePlugin: """A set of hooks intended to augment the construction of an @@ -1878,7 +2089,7 @@ class CreateEnginePlugin: """ # noqa: E501 - def __init__(self, url, kwargs): + def __init__(self, url: URL, kwargs: Dict[str, Any]): """Construct a new :class:`.CreateEnginePlugin`. The plugin object is instantiated individually for each call @@ -1905,7 +2116,7 @@ class CreateEnginePlugin: """ self.url = url - def update_url(self, url): + def update_url(self, url: URL) -> URL: """Update the :class:`_engine.URL`. A new :class:`_engine.URL` should be returned. This method is @@ -1920,14 +2131,19 @@ class CreateEnginePlugin: .. versionadded:: 1.4 """ + raise NotImplementedError() - def handle_dialect_kwargs(self, dialect_cls, dialect_args): + def handle_dialect_kwargs( + self, dialect_cls: Type[Dialect], dialect_args: Dict[str, Any] + ) -> None: """parse and modify dialect kwargs""" - def handle_pool_kwargs(self, pool_cls, pool_args): + def handle_pool_kwargs( + self, pool_cls: Type[Pool], pool_args: Dict[str, Any] + ) -> None: """parse and modify pool kwargs""" - def engine_created(self, engine): + def engine_created(self, engine: Engine) -> None: """Receive the :class:`_engine.Engine` object when it is fully constructed. @@ -1941,56 +2157,137 @@ class ExecutionContext: """A messenger object for a Dialect that corresponds to a single execution. - ExecutionContext should have these data members: + """ - connection - Connection object which can be freely used by default value + connection: Connection + """Connection object which can be freely used by default value generators to execute SQL. This Connection should reference the same underlying connection/transactional resources of - root_connection. + root_connection.""" - root_connection - Connection object which is the source of this ExecutionContext. + root_connection: Connection + """Connection object which is the source of this ExecutionContext.""" - dialect - dialect which created this ExecutionContext. + dialect: Dialect + """dialect which created this ExecutionContext.""" - cursor - DB-API cursor procured from the connection, + cursor: DBAPICursor + """DB-API cursor procured from the connection""" - compiled - if passed to constructor, sqlalchemy.engine.base.Compiled object - being executed, + compiled: Optional[Compiled] + """if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed""" - statement - string version of the statement to be executed. Is either + statement: str + """string version of the statement to be executed. Is either passed to the constructor, or must be created from the - sql.Compiled object by the time pre_exec() has completed. + sql.Compiled object by the time pre_exec() has completed.""" - parameters - bind parameters passed to the execute() method. For compiled - statements, this is a dictionary or list of dictionaries. For - textual statements, it should be in a format suitable for the - dialect's paramstyle (i.e. dict or list of dicts for non - positional, list or list of lists/tuples for positional). + invoked_statement: Optional[Executable] + """The Executable statement object that was given in the first place. - isinsert - True if the statement is an INSERT. + This should be structurally equivalent to compiled.statement, but not + necessarily the same object as in a caching scenario the compiled form + will have been extracted from the cache. - isupdate - True if the statement is an UPDATE. + """ - prefetch_cols - a list of Column objects for which a client-side default - was fired off. Applies to inserts and updates. + parameters: _AnyMultiExecuteParams + """bind parameters passed to the execute() or exec_driver_sql() methods. + + These are always stored as a list of parameter entries. A single-element + list corresponds to a ``cursor.execute()`` call and a multiple-element + list corresponds to ``cursor.executemany()``. - postfetch_cols - a list of Column objects for which a server-side default or - inline SQL expression value was fired off. Applies to inserts - and updates. """ - def create_cursor(self): + no_parameters: bool + """True if the execution style does not use parameters""" + + isinsert: bool + """True if the statement is an INSERT.""" + + isupdate: bool + """True if the statement is an UPDATE.""" + + executemany: bool + """True if the parameters have determined this to be an executemany""" + + prefetch_cols: Optional[Sequence[Column[Any]]] + """a list of Column objects for which a client-side default + was fired off. Applies to inserts and updates.""" + + postfetch_cols: Optional[Sequence[Column[Any]]] + """a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates.""" + + @classmethod + def _init_ddl( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_compiled( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: _CoreSingleExecuteParams, + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_statement( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_default( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: + raise NotImplementedError() + + def _exec_default( + self, + column: Optional[Column[Any]], + default: ColumnDefault, + type_: Optional[TypeEngine[Any]], + ) -> Any: + raise NotImplementedError() + + def _set_input_sizes(self) -> None: + raise NotImplementedError() + + def _get_cache_stats(self) -> str: + raise NotImplementedError() + + def _setup_result_proxy(self) -> Result: + raise NotImplementedError() + + def create_cursor(self) -> DBAPICursor: """Return a new cursor generated from this ExecutionContext's connection. @@ -2001,7 +2298,7 @@ class ExecutionContext: raise NotImplementedError() - def pre_exec(self): + def pre_exec(self) -> None: """Called before an execution of a compiled statement. If a compiled statement was passed to this ExecutionContext, @@ -2011,7 +2308,9 @@ class ExecutionContext: raise NotImplementedError() - def get_out_parameter_values(self, out_param_names): + def get_out_parameter_values( + self, out_param_names: Sequence[str] + ) -> Sequence[Any]: """Return a sequence of OUT parameter values from a cursor. For dialects that support OUT parameters, this method will be called @@ -2045,7 +2344,7 @@ class ExecutionContext: """ raise NotImplementedError() - def post_exec(self): + def post_exec(self) -> None: """Called after the execution of a compiled statement. If a compiled statement was passed to this ExecutionContext, @@ -2055,20 +2354,20 @@ class ExecutionContext: raise NotImplementedError() - def handle_dbapi_exception(self, e): + def handle_dbapi_exception(self, e: BaseException) -> None: """Receive a DBAPI exception which occurred upon execute, result fetch, etc.""" raise NotImplementedError() - def lastrow_has_defaults(self): + def lastrow_has_defaults(self) -> bool: """Return True if the last INSERT or UPDATE row contained inlined or database-side defaults. """ raise NotImplementedError() - def get_rowcount(self): + def get_rowcount(self) -> Optional[int]: """Return the DBAPI ``cursor.rowcount`` value, or in some cases an interpreted value. @@ -2079,7 +2378,7 @@ class ExecutionContext: raise NotImplementedError() -class ConnectionEventsTarget: +class ConnectionEventsTarget(EventTarget): """An object which can accept events from :class:`.ConnectionEvents`. Includes :class:`_engine.Connection` and :class:`_engine.Engine`. @@ -2088,6 +2387,11 @@ class ConnectionEventsTarget: """ + dispatch: dispatcher[ConnectionEventsTarget] + + +Connectable = ConnectionEventsTarget + class ExceptionContext: """Encapsulate information about an error condition in progress. @@ -2101,7 +2405,7 @@ class ExceptionContext: """ - connection = None + connection: Optional[Connection] """The :class:`_engine.Connection` in use during the exception. This member is present, except in the case of a failure when @@ -2114,7 +2418,7 @@ class ExceptionContext: """ - engine = None + engine: Optional[Engine] """The :class:`_engine.Engine` in use during the exception. This member should always be present, even in the case of a failure @@ -2124,35 +2428,35 @@ class ExceptionContext: """ - cursor = None + cursor: Optional[DBAPICursor] """The DBAPI cursor object. May be None. """ - statement = None + statement: Optional[str] """String SQL statement that was emitted directly to the DBAPI. May be None. """ - parameters = None + parameters: Optional[_DBAPIAnyExecuteParams] """Parameter collection that was emitted directly to the DBAPI. May be None. """ - original_exception = None + original_exception: BaseException """The exception object which was caught. This member is always present. """ - sqlalchemy_exception = None + sqlalchemy_exception: Optional[StatementError] """The :class:`sqlalchemy.exc.StatementError` which wraps the original, and will be raised if exception handling is not circumvented by the event. @@ -2162,7 +2466,7 @@ class ExceptionContext: """ - chained_exception = None + chained_exception: Optional[BaseException] """The exception that was returned by the previous handler in the exception chain, if any. @@ -2173,7 +2477,7 @@ class ExceptionContext: """ - execution_context = None + execution_context: Optional[ExecutionContext] """The :class:`.ExecutionContext` corresponding to the execution operation in progress. @@ -2193,7 +2497,7 @@ class ExceptionContext: """ - is_disconnect = None + is_disconnect: bool """Represent whether the exception as occurred represents a "disconnect" condition. @@ -2218,7 +2522,7 @@ class ExceptionContext: """ - invalidate_pool_on_disconnect = True + invalidate_pool_on_disconnect: bool """Represent whether all connections in the pool should be invalidated when a "disconnect" condition is in effect. @@ -2250,12 +2554,14 @@ class AdaptedConnection: __slots__ = ("_connection",) + _connection: Any + @property - def driver_connection(self): + def driver_connection(self) -> Any: """The connection object as returned by the driver after a connect.""" return self._connection - def run_async(self, fn): + def run_async(self, fn: Callable[[Any], Awaitable[_T]]) -> _T: """Run the awaitable returned by the given function, which is passed the raw asyncio driver connection. @@ -2284,5 +2590,5 @@ class AdaptedConnection: """ return await_only(fn(self._connection)) - def __repr__(self): + def __repr__(self) -> str: return "" % self._connection diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index 76e77a3f3d..a0ba966039 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -8,40 +8,69 @@ from __future__ import annotations from operator import attrgetter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import Type +from typing import Union from . import url as _url from .. import util +if typing.TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _ExecuteOptionsParameter + from .interfaces import Dialect + from .url import URL + from ..sql.base import Executable + from ..sql.ddl import DDLElement + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.schema import HasSchemaAttr + + class MockConnection: - def __init__(self, dialect, execute): + def __init__(self, dialect: Dialect, execute: Callable[..., Any]): self._dialect = dialect - self.execute = execute + self._execute_impl = execute - engine = property(lambda s: s) - dialect = property(attrgetter("_dialect")) - name = property(lambda s: s._dialect.name) + engine: Engine = cast(Any, property(lambda s: s)) + dialect: Dialect = cast(Any, property(attrgetter("_dialect"))) + name: str = cast(Any, property(lambda s: s._dialect.name)) - def connect(self, **kwargs): + def connect(self, **kwargs: Any) -> MockConnection: return self - def schema_for_object(self, obj): + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: return obj.schema - def execution_options(self, **kw): + def execution_options(self, **kw: Any) -> MockConnection: return self def _run_ddl_visitor( - self, visitorcallable, element, connection=None, **kwargs - ): + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: DDLElement, + **kwargs: Any, + ) -> None: kwargs["checkfirst"] = False visitorcallable(self.dialect, self, **kwargs).traverse_single(element) - def execute(self, object_, *multiparams, **params): - raise NotImplementedError() + def execute( + self, + obj: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: + return self._execute_impl(obj, parameters) -def create_mock_engine(url, executor, **kw): +def create_mock_engine(url: URL, executor: Any, **kw: Any) -> MockConnection: """Create a "mock" engine used for echoing DDL. This is a utility function used for debugging or storing the output of DDL @@ -96,6 +125,6 @@ def create_mock_engine(url, executor, **kw): dialect_args[k] = kw.pop(k) # create dialect - dialect = dialect_cls(**dialect_args) + dialect = dialect_cls(**dialect_args) # type: ignore return MockConnection(dialect, executor) diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py index 398c1fa361..7a6a57c03f 100644 --- a/lib/sqlalchemy/engine/processors.py +++ b/lib/sqlalchemy/engine/processors.py @@ -14,9 +14,20 @@ They all share one common characteristic: None is passed through unchanged. """ from __future__ import annotations +import typing + from ._py_processors import str_to_datetime_processor_factory # noqa +from ..util._has_cy import HAS_CYEXTENSION -try: +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_processors import int_to_boolean # noqa + from ._py_processors import str_to_date # noqa + from ._py_processors import str_to_datetime # noqa + from ._py_processors import str_to_time # noqa + from ._py_processors import to_decimal_processor_factory # noqa + from ._py_processors import to_float # noqa + from ._py_processors import to_str # noqa +else: from sqlalchemy.cyextension.processors import ( DecimalResultProcessor, ) # noqa @@ -34,12 +45,3 @@ try: # Decimal('5.00000') whereas the C implementation will # return Decimal('5'). These are equivalent of course. return DecimalResultProcessor(target_class, "%%.%df" % scale).process - -except ImportError: - from ._py_processors import int_to_boolean # noqa - from ._py_processors import str_to_date # noqa - from ._py_processors import str_to_datetime # noqa - from ._py_processors import str_to_time # noqa - from ._py_processors import to_decimal_processor_factory # noqa - from ._py_processors import to_float # noqa - from ._py_processors import to_str # noqa diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 3ba1ae519c..0951d57702 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -9,13 +9,26 @@ from __future__ import annotations -import collections.abc as collections_abc +from enum import Enum import functools import itertools import operator import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TypeVar +from typing import Union from .row import Row +from .row import RowMapping from .. import exc from .. import util from ..sql.base import _generative @@ -25,9 +38,42 @@ from ..util._has_cy import HAS_CYEXTENSION if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_row import tuplegetter + from ._py_row import tuplegetter as tuplegetter else: - from sqlalchemy.cyextension.resultproxy import tuplegetter + from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter + +if typing.TYPE_CHECKING: + from .row import RowMapping + from ..sql.schema import Column + +_KeyType = Union[str, "Column[Any]"] +_KeyIndexType = Union[str, "Column[Any]", int] + +# is overridden in cursor using _CursorKeyMapRecType +_KeyMapRecType = Any + +_KeyMapType = Dict[_KeyType, _KeyMapRecType] + + +_RowData = Union[Row, RowMapping, Any] +"""A generic form of "row" that accommodates for the different kinds of +"rows" that different result objects return, including row, row mapping, and +scalar values""" + +_RawRowType = Tuple[Any, ...] +"""represents the kind of row we get from a DBAPI cursor""" + +_InterimRowType = Union[Row, RowMapping, Any, _RawRowType] +"""a catchall "anything" kind of return type that can be applied +across all the result types + +""" + +_ProcessorType = Callable[[Any], Any] +_ProcessorsType = Sequence[Optional[_ProcessorType]] +_TupleGetterType = Callable[[Sequence[Any]], Tuple[Any, ...]] +_UniqueFilterType = Callable[[Any], Any] +_UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]] class ResultMetaData: @@ -35,40 +81,58 @@ class ResultMetaData: __slots__ = () - _tuplefilter = None - _translated_indexes = None - _unique_filters = None + _tuplefilter: Optional[_TupleGetterType] = None + _translated_indexes: Optional[Sequence[int]] = None + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None + _keymap: _KeyMapType + _keys: Sequence[str] + _processors: Optional[_ProcessorsType] @property - def keys(self): + def keys(self) -> RMKeyView: return RMKeyView(self) - def _has_key(self, key): + def _has_key(self, key: object) -> bool: raise NotImplementedError() - def _for_freeze(self): + def _for_freeze(self) -> ResultMetaData: raise NotImplementedError() - def _key_fallback(self, key, err, raiseerr=True): + def _key_fallback( + self, key: _KeyType, err: Exception, raiseerr: bool = True + ) -> NoReturn: assert raiseerr raise KeyError(key) from err - def _raise_for_nonint(self, key): - raise TypeError( - "TypeError: tuple indices must be integers or slices, not %s" - % type(key).__name__ + def _raise_for_ambiguous_column_name( + self, rec: _KeyMapRecType + ) -> NoReturn: + raise NotImplementedError( + "ambiguous column name logic is implemented for " + "CursorResultMetaData" ) - def _index_for_key(self, keys, raiseerr): + def _index_for_key( + self, key: _KeyIndexType, raiseerr: bool + ) -> Optional[int]: raise NotImplementedError() - def _metadata_for_keys(self, key): + def _indexes_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Sequence[int]: raise NotImplementedError() - def _reduce(self, keys): + def _metadata_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Iterator[_KeyMapRecType]: raise NotImplementedError() - def _getter(self, key, raiseerr=True): + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + raise NotImplementedError() + + def _getter( + self, key: Any, raiseerr: bool = True + ) -> Optional[Callable[[Sequence[_RowData]], _RowData]]: index = self._index_for_key(key, raiseerr) @@ -77,28 +141,33 @@ class ResultMetaData: else: return None - def _row_as_tuple_getter(self, keys): + def _row_as_tuple_getter( + self, keys: Sequence[_KeyIndexType] + ) -> _TupleGetterType: indexes = self._indexes_for_keys(keys) return tuplegetter(*indexes) -class RMKeyView(collections_abc.KeysView): +class RMKeyView(typing.KeysView[Any]): __slots__ = ("_parent", "_keys") - def __init__(self, parent): + _parent: ResultMetaData + _keys: Sequence[str] + + def __init__(self, parent: ResultMetaData): self._parent = parent self._keys = [k for k in parent._keys if k is not None] - def __len__(self): + def __len__(self) -> int: return len(self._keys) - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0._keys!r})".format(self) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._keys) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: if isinstance(item, int): return False @@ -106,10 +175,10 @@ class RMKeyView(collections_abc.KeysView): # which also don't seem to be tested in test_resultset right now return self._parent._has_key(item) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return list(other) == list(self) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return list(other) != list(self) @@ -125,20 +194,21 @@ class SimpleResultMetaData(ResultMetaData): "_unique_filters", ) + _keys: Sequence[str] + def __init__( self, - keys, - extra=None, - _processors=None, - _tuplefilter=None, - _translated_indexes=None, - _unique_filters=None, + keys: Sequence[str], + extra: Optional[Sequence[Any]] = None, + _processors: Optional[_ProcessorsType] = None, + _tuplefilter: Optional[_TupleGetterType] = None, + _translated_indexes: Optional[Sequence[int]] = None, + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None, ): self._keys = list(keys) self._tuplefilter = _tuplefilter self._translated_indexes = _translated_indexes self._unique_filters = _unique_filters - if extra: recs_names = [ ( @@ -157,10 +227,10 @@ class SimpleResultMetaData(ResultMetaData): self._processors = _processors - def _has_key(self, key): + def _has_key(self, key: object) -> bool: return key in self._keymap - def _for_freeze(self): + def _for_freeze(self) -> ResultMetaData: unique_filters = self._unique_filters if unique_filters and self._tuplefilter: unique_filters = self._tuplefilter(unique_filters) @@ -173,28 +243,28 @@ class SimpleResultMetaData(ResultMetaData): _unique_filters=unique_filters, ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "_keys": self._keys, "_translated_indexes": self._translated_indexes, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: if state["_translated_indexes"]: _translated_indexes = state["_translated_indexes"] _tuplefilter = tuplegetter(*_translated_indexes) else: _translated_indexes = _tuplefilter = None - self.__init__( + self.__init__( # type: ignore state["_keys"], _translated_indexes=_translated_indexes, _tuplefilter=_tuplefilter, ) - def _contains(self, value, row): + def _contains(self, value: Any, row: Row) -> bool: return value in row._data - def _index_for_key(self, key, raiseerr=True): + def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: if int in key.__class__.__mro__: key = self._keys[key] try: @@ -202,12 +272,14 @@ class SimpleResultMetaData(ResultMetaData): except KeyError as ke: rec = self._key_fallback(key, ke, raiseerr) - return rec[0] + return rec[0] # type: ignore[no-any-return] - def _indexes_for_keys(self, keys): + def _indexes_for_keys(self, keys: Sequence[Any]) -> Sequence[int]: return [self._keymap[key][0] for key in keys] - def _metadata_for_keys(self, keys): + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_KeyMapRecType]: for key in keys: if int in key.__class__.__mro__: key = self._keys[key] @@ -219,7 +291,7 @@ class SimpleResultMetaData(ResultMetaData): yield rec - def _reduce(self, keys): + def _reduce(self, keys: Sequence[Any]) -> ResultMetaData: try: metadata_for_keys = [ self._keymap[ @@ -230,7 +302,10 @@ class SimpleResultMetaData(ResultMetaData): except KeyError as ke: self._key_fallback(ke.args[0], ke, True) - indexes, new_keys, extra = zip(*metadata_for_keys) + indexes: Sequence[int] + new_keys: Sequence[str] + extra: Sequence[Any] + indexes, new_keys, extra = zip(*metadata_for_keys) # type: ignore if self._translated_indexes: indexes = [self._translated_indexes[idx] for idx in indexes] @@ -249,7 +324,9 @@ class SimpleResultMetaData(ResultMetaData): return new_metadata -def result_tuple(fields, extra=None): +def result_tuple( + fields: Sequence[str], extra: Optional[Any] = None +) -> Callable[[_RawRowType], Row]: parent = SimpleResultMetaData(fields, extra) return functools.partial( Row, parent, parent._processors, parent._keymap, Row._default_key_style @@ -259,31 +336,58 @@ def result_tuple(fields, extra=None): # a symbol that indicates to internal Result methods that # "no row is returned". We can't use None for those cases where a scalar # filter is applied to rows. -_NO_ROW = util.symbol("NO_ROW") +class _NoRow(Enum): + _NO_ROW = 0 -SelfResultInternal = typing.TypeVar( - "SelfResultInternal", bound="ResultInternal" -) + +_NO_ROW = _NoRow._NO_ROW + +SelfResultInternal = TypeVar("SelfResultInternal", bound="ResultInternal") class ResultInternal(InPlaceGenerative): - _real_result = None - _generate_rows = True - _unique_filter_state = None - _post_creational_filter = None + _real_result: Optional[Result] = None + _generate_rows: bool = True + _row_logging_fn: Optional[Callable[[Any], Any]] + + _unique_filter_state: Optional[_UniqueFilterStateType] = None + _post_creational_filter: Optional[Callable[[Any], Any]] = None _is_cursor = False + _metadata: ResultMetaData + + _source_supports_scalars: bool + + def _fetchiter_impl(self) -> Iterator[_InterimRowType]: + raise NotImplementedError() + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType]: + raise NotImplementedError() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType]: + raise NotImplementedError() + + def _fetchall_impl(self) -> List[_InterimRowType]: + raise NotImplementedError() + + def _soft_close(self, hard: bool = False) -> None: + raise NotImplementedError() + @HasMemoized.memoized_attribute - def _row_getter(self): + def _row_getter(self) -> Optional[Callable[..., _RowData]]: real_result = self._real_result if self._real_result else self if real_result._source_supports_scalars: if not self._generate_rows: return None else: - _proc = real_result._process_row + _proc = Row - def process_row( + def process_row( # type: ignore metadata, processors, keymap, key_style, scalar_obj ): return _proc( @@ -291,9 +395,9 @@ class ResultInternal(InPlaceGenerative): ) else: - process_row = real_result._process_row + process_row = Row # type: ignore - key_style = real_result._process_row._default_key_style + key_style = Row._default_key_style metadata = self._metadata keymap = metadata._keymap @@ -304,19 +408,19 @@ class ResultInternal(InPlaceGenerative): if processors: processors = tf(processors) - _make_row_orig = functools.partial( + _make_row_orig: Callable[..., Any] = functools.partial( process_row, metadata, processors, keymap, key_style ) - def make_row(row): - return _make_row_orig(tf(row)) + def make_row(row: _InterimRowType) -> _InterimRowType: + return _make_row_orig(tf(row)) # type: ignore else: - make_row = functools.partial( + make_row = functools.partial( # type: ignore process_row, metadata, processors, keymap, key_style ) - fns = () + fns: Tuple[Any, ...] = () if real_result._row_logging_fn: fns = (real_result._row_logging_fn,) @@ -326,16 +430,16 @@ class ResultInternal(InPlaceGenerative): if fns: _make_row = make_row - def make_row(row): - row = _make_row(row) + def make_row(row: _InterimRowType) -> _InterimRowType: + interim_row = _make_row(row) for fn in fns: - row = fn(row) - return row + interim_row = fn(interim_row) + return interim_row return make_row @HasMemoized.memoized_attribute - def _iterator_getter(self): + def _iterator_getter(self) -> Callable[..., Iterator[_RowData]]: make_row = self._row_getter @@ -344,9 +448,9 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def iterrows(self): - for row in self._fetchiter_impl(): - obj = make_row(row) if make_row else row + def iterrows(self: Result) -> Iterator[_RowData]: + for raw_row in self._fetchiter_impl(): + obj = make_row(raw_row) if make_row else raw_row hashed = strategy(obj) if strategy else obj if hashed in uniques: continue @@ -357,27 +461,29 @@ class ResultInternal(InPlaceGenerative): else: - def iterrows(self): + def iterrows(self: Result) -> Iterator[_RowData]: for row in self._fetchiter_impl(): - row = make_row(row) if make_row else row + row = make_row(row) if make_row else row # type: ignore if post_creational_filter: row = post_creational_filter(row) yield row return iterrows - def _raw_all_rows(self): + def _raw_all_rows(self) -> List[_RowData]: make_row = self._row_getter + assert make_row is not None rows = self._fetchall_impl() return [make_row(row) for row in rows] - def _allrows(self): + def _allrows(self) -> List[_RowData]: post_creational_filter = self._post_creational_filter make_row = self._row_getter rows = self._fetchall_impl() + made_rows: List[_InterimRowType] if make_row: made_rows = [make_row(row) for row in rows] else: @@ -386,7 +492,7 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - rows = [ + interim_rows = [ made_row for made_row, sig_row in [ ( @@ -395,17 +501,19 @@ class ResultInternal(InPlaceGenerative): ) for made_row in made_rows ] - if sig_row not in uniques and not uniques.add(sig_row) + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa E501 ] else: - rows = made_rows + interim_rows = made_rows if post_creational_filter: - rows = [post_creational_filter(row) for row in rows] - return rows + interim_rows = [ + post_creational_filter(row) for row in interim_rows + ] + return interim_rows @HasMemoized.memoized_attribute - def _onerow_getter(self): + def _onerow_getter(self) -> Callable[..., Union[_NoRow, _RowData]]: make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -413,7 +521,7 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def onerow(self): + def onerow(self: Result) -> Union[_NoRow, _RowData]: _onerow = self._fetchone_impl while True: row = _onerow() @@ -432,20 +540,22 @@ class ResultInternal(InPlaceGenerative): else: - def onerow(self): + def onerow(self: Result) -> Union[_NoRow, _RowData]: row = self._fetchone_impl() if row is None: return _NO_ROW else: - row = make_row(row) if make_row else row + interim_row: _InterimRowType = ( + make_row(row) if make_row else row + ) if post_creational_filter: - row = post_creational_filter(row) - return row + interim_row = post_creational_filter(interim_row) + return interim_row return onerow @HasMemoized.memoized_attribute - def _manyrow_getter(self): + def _manyrow_getter(self) -> Callable[..., List[_RowData]]: make_row = self._row_getter post_creational_filter = self._post_creational_filter @@ -453,7 +563,12 @@ class ResultInternal(InPlaceGenerative): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def filterrows(make_row, rows, strategy, uniques): + def filterrows( + make_row: Optional[Callable[..., _RowData]], + rows: List[Any], + strategy: Optional[Callable[[Sequence[Any]], Any]], + uniques: Set[Any], + ) -> List[Row]: if make_row: rows = [make_row(row) for row in rows] @@ -466,11 +581,11 @@ class ResultInternal(InPlaceGenerative): return [ made_row for made_row, sig_row in made_rows - if sig_row not in uniques and not uniques.add(sig_row) + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501 ] - def manyrows(self, num): - collect = [] + def manyrows(self: Result, num: Optional[int]) -> List[_RowData]: + collect: List[_RowData] = [] _manyrows = self._fetchmany_impl @@ -488,6 +603,7 @@ class ResultInternal(InPlaceGenerative): else: rows = _manyrows(num) num = len(rows) + assert make_row is not None collect.extend( filterrows(make_row, rows, strategy, uniques) ) @@ -495,6 +611,8 @@ class ResultInternal(InPlaceGenerative): else: num_required = num + assert num is not None + while num_required: rows = _manyrows(num_required) if not rows: @@ -511,14 +629,14 @@ class ResultInternal(InPlaceGenerative): else: - def manyrows(self, num): + def manyrows(self: Result, num: Optional[int]) -> List[_RowData]: if num is None: real_result = ( self._real_result if self._real_result else self ) num = real_result._yield_per - rows = self._fetchmany_impl(num) + rows: List[_InterimRowType] = self._fetchmany_impl(num) if make_row: rows = [make_row(row) for row in rows] if post_creational_filter: @@ -529,13 +647,13 @@ class ResultInternal(InPlaceGenerative): def _only_one_row( self, - raise_for_second_row, - raise_for_none, - scalar, - ): + raise_for_second_row: bool, + raise_for_none: bool, + scalar: bool, + ) -> Optional[_RowData]: onerow = self._fetchone_impl - row = onerow(hard_close=True) + row: _InterimRowType = onerow(hard_close=True) if row is None: if raise_for_none: raise exc.NoResultFound( @@ -565,7 +683,7 @@ class ResultInternal(InPlaceGenerative): existing_row_hash = strategy(row) if strategy else row while True: - next_row = onerow(hard_close=True) + next_row: Any = onerow(hard_close=True) if next_row is None: next_row = _NO_ROW break @@ -574,6 +692,7 @@ class ResultInternal(InPlaceGenerative): next_row = make_row(next_row) if make_row else next_row if strategy: + assert next_row is not _NO_ROW if existing_row_hash == strategy(next_row): continue elif row == next_row: @@ -608,14 +727,14 @@ class ResultInternal(InPlaceGenerative): row = post_creational_filter(row) if scalar and make_row: - return row[0] + return row[0] # type: ignore else: return row - def _iter_impl(self): + def _iter_impl(self) -> Iterator[_RowData]: return self._iterator_getter(self) - def _next_impl(self): + def _next_impl(self) -> _RowData: row = self._onerow_getter(self) if row is _NO_ROW: raise StopIteration() @@ -624,11 +743,14 @@ class ResultInternal(InPlaceGenerative): @_generative def _column_slices( - self: SelfResultInternal, indexes + self: SelfResultInternal, indexes: Sequence[_KeyIndexType] ) -> SelfResultInternal: real_result = self._real_result if self._real_result else self - if real_result._source_supports_scalars and len(indexes) == 1: + if ( + real_result._source_supports_scalars # type: ignore[attr-defined] # noqa E501 + and len(indexes) == 1 + ): self._generate_rows = False else: self._generate_rows = True @@ -637,7 +759,8 @@ class ResultInternal(InPlaceGenerative): return self @HasMemoized.memoized_attribute - def _unique_strategy(self): + def _unique_strategy(self) -> _UniqueFilterStateType: + assert self._unique_filter_state is not None uniques, strategy = self._unique_filter_state real_result = ( @@ -660,8 +783,10 @@ class ResultInternal(InPlaceGenerative): class _WithKeys: + _metadata: ResultMetaData + # used mainly to share documentation on the keys method. - def keys(self): + def keys(self) -> RMKeyView: """Return an iterable view which yields the string keys that would be represented by each :class:`.Row`. @@ -681,7 +806,7 @@ class _WithKeys: return self._metadata.keys -SelfResult = typing.TypeVar("SelfResult", bound="Result") +SelfResult = TypeVar("SelfResult", bound="Result") class Result(_WithKeys, ResultInternal): @@ -709,23 +834,18 @@ class Result(_WithKeys, ResultInternal): """ - _process_row = Row - - _row_logging_fn = None + _row_logging_fn: Optional[Callable[[Row], Row]] = None - _source_supports_scalars = False + _source_supports_scalars: bool = False - _yield_per = None + _yield_per: Optional[int] = None - _attributes = util.immutabledict() + _attributes: util.immutabledict[Any, Any] = util.immutabledict() - def __init__(self, cursor_metadata): + def __init__(self, cursor_metadata: ResultMetaData): self._metadata = cursor_metadata - def _soft_close(self, hard=False): - raise NotImplementedError() - - def close(self): + def close(self) -> None: """close this :class:`_result.Result`. The behavior of this method is implementation specific, and is @@ -748,7 +868,7 @@ class Result(_WithKeys, ResultInternal): self._soft_close(hard=True) @_generative - def yield_per(self: SelfResult, num) -> SelfResult: + def yield_per(self: SelfResult, num: int) -> SelfResult: """Configure the row-fetching strategy to fetch num rows at a time. This impacts the underlying behavior of the result when iterating over @@ -785,7 +905,9 @@ class Result(_WithKeys, ResultInternal): return self @_generative - def unique(self: SelfResult, strategy=None) -> SelfResult: + def unique( + self: SelfResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfResult: """Apply unique filtering to the objects returned by this :class:`_engine.Result`. @@ -826,7 +948,7 @@ class Result(_WithKeys, ResultInternal): return self def columns( - self: SelfResultInternal, *col_expressions + self: SelfResultInternal, *col_expressions: _KeyIndexType ) -> SelfResultInternal: r"""Establish the columns that should be returned in each row. @@ -865,7 +987,7 @@ class Result(_WithKeys, ResultInternal): """ return self._column_slices(col_expressions) - def scalars(self, index=0) -> "ScalarResult": + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult: """Return a :class:`_result.ScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -890,7 +1012,9 @@ class Result(_WithKeys, ResultInternal): """ return ScalarResult(self, index) - def _getter(self, key, raiseerr=True): + def _getter( + self, key: _KeyIndexType, raiseerr: bool = True + ) -> Optional[Callable[[Sequence[Any]], _RowData]]: """return a callable that will retrieve the given key from a :class:`.Row`. @@ -901,7 +1025,7 @@ class Result(_WithKeys, ResultInternal): ) return self._metadata._getter(key, raiseerr) - def _tuple_getter(self, keys): + def _tuple_getter(self, keys: Sequence[_KeyIndexType]) -> _TupleGetterType: """return a callable that will retrieve the given keys from a :class:`.Row`. @@ -912,7 +1036,7 @@ class Result(_WithKeys, ResultInternal): ) return self._metadata._row_as_tuple_getter(keys) - def mappings(self) -> "MappingResult": + def mappings(self) -> MappingResult: """Apply a mappings filter to returned rows, returning an instance of :class:`_result.MappingResult`. @@ -928,7 +1052,7 @@ class Result(_WithKeys, ResultInternal): return MappingResult(self) - def _raw_row_iterator(self): + def _raw_row_iterator(self) -> Iterator[_RowData]: """Return a safe iterator that yields raw row data. This is used by the :meth:`._engine.Result.merge` method @@ -937,25 +1061,13 @@ class Result(_WithKeys, ResultInternal): """ raise NotImplementedError() - def _fetchiter_impl(self): - raise NotImplementedError() - - def _fetchone_impl(self, hard_close=False): - raise NotImplementedError() - - def _fetchall_impl(self): - raise NotImplementedError() - - def _fetchmany_impl(self, size=None): - raise NotImplementedError() - - def __iter__(self): + def __iter__(self) -> Iterator[_RowData]: return self._iter_impl() - def __next__(self): + def __next__(self) -> _RowData: return self._next_impl() - def partitions(self, size=None): + def partitions(self, size: Optional[int] = None) -> Iterator[List[Row]]: """Iterate through sub-lists of rows of the size given. Each list will be of the size given, excluding the last list to @@ -989,16 +1101,16 @@ class Result(_WithKeys, ResultInternal): while True: partition = getter(self, size) if partition: - yield partition + yield partition # type: ignore else: break - def fetchall(self): + def fetchall(self) -> List[Row]: """A synonym for the :meth:`_engine.Result.all` method.""" - return self._allrows() + return self._allrows() # type: ignore[return-value] - def fetchone(self): + def fetchone(self) -> Optional[Row]: """Fetch one row. When all rows are exhausted, returns None. @@ -1018,9 +1130,9 @@ class Result(_WithKeys, ResultInternal): if row is _NO_ROW: return None else: - return row + return row # type: ignore[return-value] - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[Row]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -1035,9 +1147,9 @@ class Result(_WithKeys, ResultInternal): """ - return self._manyrow_getter(self, size) + return self._manyrow_getter(self, size) # type: ignore[return-value] - def all(self): + def all(self) -> List[Row]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -1049,9 +1161,9 @@ class Result(_WithKeys, ResultInternal): """ - return self._allrows() + return self._allrows() # type: ignore[return-value] - def first(self): + def first(self) -> Optional[Row]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -1083,11 +1195,11 @@ class Result(_WithKeys, ResultInternal): """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self): + def one_or_none(self) -> Optional[Row]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -1107,11 +1219,11 @@ class Result(_WithKeys, ResultInternal): :meth:`_result.Result.one` """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=False, scalar=False ) - def scalar_one(self): + def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`.Result.scalars` and then @@ -1128,7 +1240,7 @@ class Result(_WithKeys, ResultInternal): raise_for_second_row=True, raise_for_none=True, scalar=True ) - def scalar_one_or_none(self): + def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`.Result.scalars` and then @@ -1145,7 +1257,7 @@ class Result(_WithKeys, ResultInternal): raise_for_second_row=True, raise_for_none=False, scalar=True ) - def one(self): + def one(self) -> Row: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -1172,11 +1284,11 @@ class Result(_WithKeys, ResultInternal): :meth:`_result.Result.scalar_one` """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=True, scalar=False ) - def scalar(self): + def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. Returns None if there are no rows to fetch. @@ -1194,7 +1306,7 @@ class Result(_WithKeys, ResultInternal): raise_for_second_row=False, raise_for_none=False, scalar=True ) - def freeze(self): + def freeze(self) -> FrozenResult: """Return a callable object that will produce copies of this :class:`.Result` when invoked. @@ -1217,7 +1329,7 @@ class Result(_WithKeys, ResultInternal): return FrozenResult(self) - def merge(self, *others): + def merge(self, *others: Result) -> MergedResult: """Merge this :class:`.Result` with other compatible result objects. @@ -1240,28 +1352,37 @@ class FilterResult(ResultInternal): """ - _post_creational_filter = None + _post_creational_filter: Optional[Callable[[Any], Any]] = None - def _soft_close(self, hard=False): + _real_result: Result + + def _soft_close(self, hard: bool = False) -> None: self._real_result._soft_close(hard=hard) @property - def _attributes(self): + def _attributes(self) -> Dict[Any, Any]: return self._real_result._attributes - def _fetchiter_impl(self): + def _fetchiter_impl(self) -> Iterator[_InterimRowType]: return self._real_result._fetchiter_impl() - def _fetchone_impl(self, hard_close=False): + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType]: return self._real_result._fetchone_impl(hard_close=hard_close) - def _fetchall_impl(self): + def _fetchall_impl(self) -> List[_InterimRowType]: return self._real_result._fetchall_impl() - def _fetchmany_impl(self, size=None): + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType]: return self._real_result._fetchmany_impl(size=size) +SelfScalarResult = TypeVar("SelfScalarResult", bound="ScalarResult") + + class ScalarResult(FilterResult): """A wrapper for a :class:`_result.Result` that returns scalar values rather than :class:`_row.Row` values. @@ -1280,7 +1401,9 @@ class ScalarResult(FilterResult): _generate_rows = False - def __init__(self, real_result, index): + _post_creational_filter: Optional[Callable[[Any], Any]] + + def __init__(self, real_result: Result, index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -1292,7 +1415,9 @@ class ScalarResult(FilterResult): self._unique_filter_state = real_result._unique_filter_state - def unique(self, strategy=None): + def unique( + self: SelfScalarResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfScalarResult: """Apply unique filtering to the objects returned by this :class:`_engine.ScalarResult`. @@ -1302,7 +1427,7 @@ class ScalarResult(FilterResult): self._unique_filter_state = (set(), strategy) return self - def partitions(self, size=None): + def partitions(self, size: Optional[int] = None) -> Iterator[List[Any]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1320,12 +1445,12 @@ class ScalarResult(FilterResult): else: break - def fetchall(self): + def fetchall(self) -> List[Any]: """A synonym for the :meth:`_engine.ScalarResult.all` method.""" return self._allrows() - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[Any]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1335,7 +1460,7 @@ class ScalarResult(FilterResult): """ return self._manyrow_getter(self, size) - def all(self): + def all(self) -> List[Any]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1345,13 +1470,13 @@ class ScalarResult(FilterResult): """ return self._allrows() - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return self._iter_impl() - def __next__(self): + def __next__(self) -> Any: return self._next_impl() - def first(self): + def first(self) -> Optional[Any]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_result.Result.first` except that @@ -1364,7 +1489,7 @@ class ScalarResult(FilterResult): raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self): + def one_or_none(self) -> Optional[Any]: """Return at most one object or raise an exception. Equivalent to :meth:`_result.Result.one_or_none` except that @@ -1376,7 +1501,7 @@ class ScalarResult(FilterResult): raise_for_second_row=True, raise_for_none=False, scalar=False ) - def one(self): + def one(self) -> Any: """Return exactly one object or raise an exception. Equivalent to :meth:`_result.Result.one` except that @@ -1389,6 +1514,9 @@ class ScalarResult(FilterResult): ) +SelfMappingResult = TypeVar("SelfMappingResult", bound="MappingResult") + + class MappingResult(_WithKeys, FilterResult): """A wrapper for a :class:`_engine.Result` that returns dictionary values rather than :class:`_engine.Row` values. @@ -1402,14 +1530,16 @@ class MappingResult(_WithKeys, FilterResult): _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result): + def __init__(self, result: Result): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata if result._source_supports_scalars: self._metadata = self._metadata._reduce([0]) - def unique(self, strategy=None): + def unique( + self: SelfMappingResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfMappingResult: """Apply unique filtering to the objects returned by this :class:`_engine.MappingResult`. @@ -1419,11 +1549,15 @@ class MappingResult(_WithKeys, FilterResult): self._unique_filter_state = (set(), strategy) return self - def columns(self, *col_expressions): + def columns( + self: SelfMappingResult, *col_expressions: _KeyIndexType + ) -> SelfMappingResult: r"""Establish the columns that should be returned in each row.""" return self._column_slices(col_expressions) - def partitions(self, size=None): + def partitions( + self, size: Optional[int] = None + ) -> Iterator[List[RowMapping]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_result.Result.partitions` except that @@ -1437,16 +1571,16 @@ class MappingResult(_WithKeys, FilterResult): while True: partition = getter(self, size) if partition: - yield partition + yield partition # type: ignore else: break - def fetchall(self): + def fetchall(self) -> List[RowMapping]: """A synonym for the :meth:`_engine.MappingResult.all` method.""" - return self._allrows() + return self._allrows() # type: ignore[return-value] - def fetchone(self): + def fetchone(self) -> Optional[RowMapping]: """Fetch one object. Equivalent to :meth:`_result.Result.fetchone` except that @@ -1459,9 +1593,9 @@ class MappingResult(_WithKeys, FilterResult): if row is _NO_ROW: return None else: - return row + return row # type: ignore[return-value] - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: """Fetch many objects. Equivalent to :meth:`_result.Result.fetchmany` except that @@ -1470,9 +1604,9 @@ class MappingResult(_WithKeys, FilterResult): """ - return self._manyrow_getter(self, size) + return self._manyrow_getter(self, size) # type: ignore[return-value] - def all(self): + def all(self) -> List[RowMapping]: """Return all scalar values in a list. Equivalent to :meth:`_result.Result.all` except that @@ -1481,15 +1615,15 @@ class MappingResult(_WithKeys, FilterResult): """ - return self._allrows() + return self._allrows() # type: ignore[return-value] - def __iter__(self): - return self._iter_impl() + def __iter__(self) -> Iterator[RowMapping]: + return self._iter_impl() # type: ignore[return-value] - def __next__(self): - return self._next_impl() + def __next__(self) -> RowMapping: + return self._next_impl() # type: ignore[return-value] - def first(self): + def first(self) -> Optional[RowMapping]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_result.Result.first` except that @@ -1498,11 +1632,11 @@ class MappingResult(_WithKeys, FilterResult): """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self): + def one_or_none(self) -> Optional[RowMapping]: """Return at most one object or raise an exception. Equivalent to :meth:`_result.Result.one_or_none` except that @@ -1510,11 +1644,11 @@ class MappingResult(_WithKeys, FilterResult): are returned. """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=False, scalar=False ) - def one(self): + def one(self) -> RowMapping: """Return exactly one object or raise an exception. Equivalent to :meth:`_result.Result.one` except that @@ -1522,7 +1656,7 @@ class MappingResult(_WithKeys, FilterResult): are returned. """ - return self._only_one_row( + return self._only_one_row( # type: ignore[return-value] raise_for_second_row=True, raise_for_none=True, scalar=False ) @@ -1566,7 +1700,9 @@ class FrozenResult: """ - def __init__(self, result): + data: Sequence[Any] + + def __init__(self, result: Result): self.metadata = result._metadata._for_freeze() self._source_supports_scalars = result._source_supports_scalars self._attributes = result._attributes @@ -1576,13 +1712,13 @@ class FrozenResult: else: self.data = result.fetchall() - def rewrite_rows(self): + def rewrite_rows(self) -> List[List[Any]]: if self._source_supports_scalars: return [[elem] for elem in self.data] else: return [list(row) for row in self.data] - def with_new_rows(self, tuple_data): + def with_new_rows(self, tuple_data: Sequence[Row]) -> FrozenResult: fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._attributes = self._attributes @@ -1594,7 +1730,7 @@ class FrozenResult: fr.data = tuple_data return fr - def __call__(self): + def __call__(self) -> Result: result = IteratorResult(self.metadata, iter(self.data)) result._attributes = self._attributes result._source_supports_scalars = self._source_supports_scalars @@ -1603,7 +1739,7 @@ class FrozenResult: class IteratorResult(Result): """A :class:`.Result` that gets data from a Python iterator of - :class:`.Row` objects. + :class:`.Row` objects or similar row-like data. .. versionadded:: 1.4 @@ -1613,17 +1749,17 @@ class IteratorResult(Result): def __init__( self, - cursor_metadata, - iterator, - raw=None, - _source_supports_scalars=False, + cursor_metadata: ResultMetaData, + iterator: Iterator[_RowData], + raw: Optional[Any] = None, + _source_supports_scalars: bool = False, ): self._metadata = cursor_metadata self.iterator = iterator self.raw = raw self._source_supports_scalars = _source_supports_scalars - def _soft_close(self, hard=False, **kw): + def _soft_close(self, hard: bool = False, **kw: Any) -> None: if hard: self._hard_closed = True if self.raw is not None: @@ -1631,18 +1767,18 @@ class IteratorResult(Result): self.iterator = iter([]) self._reset_memoizations() - def _raise_hard_closed(self): + def _raise_hard_closed(self) -> NoReturn: raise exc.ResourceClosedError("This result object is closed.") - def _raw_row_iterator(self): + def _raw_row_iterator(self) -> Iterator[_RowData]: return self.iterator - def _fetchiter_impl(self): + def _fetchiter_impl(self) -> Iterator[_RowData]: if self._hard_closed: self._raise_hard_closed() return self.iterator - def _fetchone_impl(self, hard_close=False): + def _fetchone_impl(self, hard_close: bool = False) -> Optional[_RowData]: if self._hard_closed: self._raise_hard_closed() @@ -1653,27 +1789,26 @@ class IteratorResult(Result): else: return row - def _fetchall_impl(self): + def _fetchall_impl(self) -> List[_RowData]: if self._hard_closed: self._raise_hard_closed() - try: return list(self.iterator) finally: self._soft_close() - def _fetchmany_impl(self, size=None): + def _fetchmany_impl(self, size: Optional[int] = None) -> List[_RowData]: if self._hard_closed: self._raise_hard_closed() return list(itertools.islice(self.iterator, 0, size)) -def null_result(): +def null_result() -> IteratorResult: return IteratorResult(SimpleResultMetaData([]), iter([])) -SelfChunkedIteratorResult = typing.TypeVar( +SelfChunkedIteratorResult = TypeVar( "SelfChunkedIteratorResult", bound="ChunkedIteratorResult" ) @@ -1695,11 +1830,11 @@ class ChunkedIteratorResult(IteratorResult): def __init__( self, - cursor_metadata, - chunks, - source_supports_scalars=False, - raw=None, - dynamic_yield_per=False, + cursor_metadata: ResultMetaData, + chunks: Callable[[Optional[int]], Iterator[List[_InterimRowType]]], + source_supports_scalars: bool = False, + raw: Optional[Any] = None, + dynamic_yield_per: bool = False, ): self._metadata = cursor_metadata self.chunks = chunks @@ -1710,7 +1845,7 @@ class ChunkedIteratorResult(IteratorResult): @_generative def yield_per( - self: SelfChunkedIteratorResult, num + self: SelfChunkedIteratorResult, num: int ) -> SelfChunkedIteratorResult: # TODO: this throws away the iterator which may be holding # onto a chunk. the yield_per cannot be changed once any @@ -1722,11 +1857,13 @@ class ChunkedIteratorResult(IteratorResult): self.iterator = itertools.chain.from_iterable(self.chunks(num)) return self - def _soft_close(self, **kw): - super(ChunkedIteratorResult, self)._soft_close(**kw) - self.chunks = lambda size: [] + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + super(ChunkedIteratorResult, self)._soft_close(hard=hard, **kw) + self.chunks = lambda size: [] # type: ignore - def _fetchmany_impl(self, size=None): + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType]: if self.dynamic_yield_per: self.iterator = itertools.chain.from_iterable(self.chunks(size)) return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size) @@ -1744,7 +1881,9 @@ class MergedResult(IteratorResult): closed = False - def __init__(self, cursor_metadata, results): + def __init__( + self, cursor_metadata: ResultMetaData, results: Sequence[Result] + ): self._results = results super(MergedResult, self).__init__( cursor_metadata, @@ -1763,7 +1902,7 @@ class MergedResult(IteratorResult): *[r._attributes for r in results] ) - def _soft_close(self, hard=False, **kw): + def _soft_close(self, hard: bool = False, **kw: Any) -> None: for r in self._results: r._soft_close(hard=hard, **kw) if hard: diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 29b2f338b6..ff63199d40 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -9,24 +9,41 @@ from __future__ import annotations +from abc import ABC import collections.abc as collections_abc import operator import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Union from ..sql import util as sql_util from ..util._has_cy import HAS_CYEXTENSION if typing.TYPE_CHECKING or not HAS_CYEXTENSION: - from ._py_row import BaseRow + from ._py_row import BaseRow as BaseRow from ._py_row import KEY_INTEGER_ONLY from ._py_row import KEY_OBJECTS_ONLY else: - from sqlalchemy.cyextension.resultproxy import BaseRow + from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY +if typing.TYPE_CHECKING: + from .result import _KeyType + from .result import RMKeyView -class Row(BaseRow, collections_abc.Sequence): + +class Row(BaseRow, typing.Sequence[Any]): """Represent a single result row. The :class:`.Row` object represents a row of a database result. It is @@ -58,14 +75,14 @@ class Row(BaseRow, collections_abc.Sequence): _default_key_style = KEY_INTEGER_ONLY - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> NoReturn: raise AttributeError("can't set attribute") - def __delattr__(self, name): + def __delattr__(self, name: str) -> NoReturn: raise AttributeError("can't delete attribute") @property - def _mapping(self): + def _mapping(self) -> RowMapping: """Return a :class:`.RowMapping` for this :class:`.Row`. This object provides a consistent Python mapping (i.e. dictionary) @@ -87,31 +104,44 @@ class Row(BaseRow, collections_abc.Sequence): self._data, ) - def _special_name_accessor(name): - """Handle ambiguous names such as "count" and "index" """ + def _filter_on_values( + self, filters: Optional[Sequence[Optional[Callable[[Any], Any]]]] + ) -> Row: + return Row( + self._parent, + filters, + self._keymap, + self._key_style, + self._data, + ) + + if not typing.TYPE_CHECKING: + + def _special_name_accessor(name: str) -> Any: + """Handle ambiguous names such as "count" and "index" """ - @property - def go(self): - if self._parent._has_key(name): - return self.__getattr__(name) - else: + @property + def go(self: Row) -> Any: + if self._parent._has_key(name): + return self.__getattr__(name) + else: - def meth(*arg, **kw): - return getattr(collections_abc.Sequence, name)( - self, *arg, **kw - ) + def meth(*arg: Any, **kw: Any) -> Any: + return getattr(collections_abc.Sequence, name)( + self, *arg, **kw + ) - return meth + return meth - return go + return go - count = _special_name_accessor("count") - index = _special_name_accessor("index") + count = _special_name_accessor("count") + index = _special_name_accessor("index") - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: return key in self._data - def _op(self, other, op): + def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool: return ( op(tuple(self), tuple(other)) if isinstance(other, Row) @@ -120,29 +150,44 @@ class Row(BaseRow, collections_abc.Sequence): __hash__ = BaseRow.__hash__ - def __lt__(self, other): + if typing.TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> Any: + ... + + @overload + def __getitem__(self, index: slice) -> Sequence[Any]: + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[Any, Sequence[Any]]: + ... + + def __lt__(self, other: Any) -> bool: return self._op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: Any) -> bool: return self._op(other, operator.le) - def __ge__(self, other): + def __ge__(self, other: Any) -> bool: return self._op(other, operator.ge) - def __gt__(self, other): + def __gt__(self, other: Any) -> bool: return self._op(other, operator.gt) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self._op(other, operator.eq) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return self._op(other, operator.ne) - def __repr__(self): + def __repr__(self) -> str: return repr(sql_util._repr_row(self)) @property - def _fields(self): + def _fields(self) -> Tuple[str, ...]: """Return a tuple of string keys as represented by this :class:`.Row`. @@ -162,7 +207,7 @@ class Row(BaseRow, collections_abc.Sequence): """ return tuple([k for k in self._parent.keys if k is not None]) - def _asdict(self): + def _asdict(self) -> Dict[str, Any]: """Return a new dict which maps field names to their corresponding values. @@ -179,49 +224,51 @@ class Row(BaseRow, collections_abc.Sequence): """ return dict(self._mapping) - def _replace(self): - raise NotImplementedError() - - @property - def _field_defaults(self): - raise NotImplementedError() - BaseRowProxy = BaseRow RowProxy = Row -class ROMappingView( - collections_abc.KeysView, - collections_abc.ValuesView, - collections_abc.ItemsView, -): - __slots__ = ("_items",) +class ROMappingView(ABC): + __slots__ = () + + _items: Sequence[Any] + _mapping: Mapping[str, Any] - def __init__(self, mapping, items): + def __init__(self, mapping: Mapping[str, Any], items: Sequence[Any]): self._mapping = mapping self._items = items - def __len__(self): + def __len__(self) -> int: return len(self._items) - def __repr__(self): + def __repr__(self) -> str: return "{0.__class__.__name__}({0._mapping!r})".format(self) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self._items) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: return item in self._items - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return list(other) == list(self) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return list(other) != list(self) -class RowMapping(BaseRow, collections_abc.Mapping): +class ROMappingKeysValuesView( + ROMappingView, typing.KeysView[str], typing.ValuesView[Any] +): + __slots__ = ("_items",) + + +class ROMappingItemsView(ROMappingView, typing.ItemsView[str, Any]): + __slots__ = ("_items",) + + +class RowMapping(BaseRow, typing.Mapping[str, Any]): """A ``Mapping`` that maps column names and objects to :class:`.Row` values. The :class:`.RowMapping` is available from a :class:`.Row` via the @@ -251,31 +298,39 @@ class RowMapping(BaseRow, collections_abc.Mapping): _default_key_style = KEY_OBJECTS_ONLY - __getitem__ = BaseRow._get_by_key_impl_mapping + if typing.TYPE_CHECKING: - def _values_impl(self): + def __getitem__(self, key: _KeyType) -> Any: + ... + + else: + __getitem__ = BaseRow._get_by_key_impl_mapping + + def _values_impl(self) -> List[Any]: return list(self._data) - def __iter__(self): + def __iter__(self) -> Iterator[str]: return (k for k in self._parent.keys if k is not None) - def __len__(self): + def __len__(self) -> int: return len(self._data) - def __contains__(self, key): + def __contains__(self, key: object) -> bool: return self._parent._has_key(key) - def __repr__(self): + def __repr__(self) -> str: return repr(dict(self)) - def items(self): + def items(self) -> ROMappingItemsView: """Return a view of key/value tuples for the elements in the underlying :class:`.Row`. """ - return ROMappingView(self, [(key, self[key]) for key in self.keys()]) + return ROMappingItemsView( + self, [(key, self[key]) for key in self.keys()] + ) - def keys(self): + def keys(self) -> RMKeyView: """Return a view of 'keys' for string column names represented by the underlying :class:`.Row`. @@ -283,9 +338,9 @@ class RowMapping(BaseRow, collections_abc.Mapping): return self._parent.keys - def values(self): + def values(self) -> ROMappingKeysValuesView: """Return a view of values for the values represented in the underlying :class:`.Row`. """ - return ROMappingView(self, self._values_impl()) + return ROMappingKeysValuesView(self, self._values_impl()) diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index a55233397e..306989e0ba 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -18,10 +18,18 @@ from __future__ import annotations import collections.abc as collections_abc import re +from typing import Any +from typing import cast from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping from typing import NamedTuple from typing import Optional +from typing import overload +from typing import Sequence from typing import Tuple +from typing import Type from typing import Union from urllib.parse import parse_qsl from urllib.parse import quote_plus @@ -86,19 +94,19 @@ class URL(NamedTuple): host: Optional[str] port: Optional[int] database: Optional[str] - query: Dict[str, Union[str, Tuple[str]]] + query: util.immutabledict[str, Union[Tuple[str, ...], str]] @classmethod def create( cls, - drivername, - username=None, - password=None, - host=None, - port=None, - database=None, - query=util.EMPTY_DICT, - ): + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = util.EMPTY_DICT, + ) -> URL: """Create a new :class:`_engine.URL` object. :param drivername: the name of the database backend. This name will @@ -146,7 +154,7 @@ class URL(NamedTuple): ) @classmethod - def _assert_port(cls, port): + def _assert_port(cls, port: Optional[int]) -> Optional[int]: if port is None: return None try: @@ -155,24 +163,48 @@ class URL(NamedTuple): raise TypeError("Port argument must be an integer or None") @classmethod - def _assert_str(cls, v, paramname): + def _assert_str(cls, v: str, paramname: str) -> str: if not isinstance(v, str): raise TypeError("%s must be a string" % paramname) return v @classmethod - def _assert_none_str(cls, v, paramname): + def _assert_none_str( + cls, v: Optional[str], paramname: str + ) -> Optional[str]: if v is None: return v return cls._assert_str(v, paramname) @classmethod - def _str_dict(cls, dict_): + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> util.immutabledict[str, Union[Tuple[str, ...], str]]: if dict_ is None: return util.EMPTY_DICT - def _assert_value(val): + @overload + def _assert_value( + val: str, + ) -> str: + ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: + ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: if isinstance(val, str): return val elif isinstance(val, collections_abc.Sequence): @@ -183,11 +215,12 @@ class URL(NamedTuple): "sequences of strings" ) - def _assert_str(v): + def _assert_str(v: str) -> str: if not isinstance(v, str): raise TypeError("Query dictionary keys must be strings") return v + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] if isinstance(dict_, collections_abc.Sequence): dict_items = dict_ else: @@ -204,14 +237,14 @@ class URL(NamedTuple): def set( self, - drivername=None, - username=None, - password=None, - host=None, - port=None, - database=None, - query=None, - ): + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> URL: """return a new :class:`_engine.URL` object with modifications. Values are used if they are non-None. To set a value to ``None`` @@ -237,7 +270,7 @@ class URL(NamedTuple): """ - kw = {} + kw: Dict[str, Any] = {} if drivername is not None: kw["drivername"] = drivername if username is not None: @@ -255,7 +288,7 @@ class URL(NamedTuple): return self._assert_replace(**kw) - def _assert_replace(self, **kw): + def _assert_replace(self, **kw: Any) -> URL: """argument checks before calling _replace()""" if "drivername" in kw: @@ -270,7 +303,9 @@ class URL(NamedTuple): return self._replace(**kw) - def update_query_string(self, query_string, append=False): + def update_query_string( + self, query_string: str, append: bool = False + ) -> URL: """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` parameter dictionary updated by the given query string. @@ -301,7 +336,11 @@ class URL(NamedTuple): """ # noqa: E501 return self.update_query_pairs(parse_qsl(query_string), append=append) - def update_query_pairs(self, key_value_pairs, append=False): + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> URL: """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` parameter dictionary updated by the given sequence of key/value pairs @@ -335,23 +374,27 @@ class URL(NamedTuple): """ # noqa: E501 existing_query = self.query - new_keys = {} + new_keys: Dict[str, Union[str, List[str]]] = {} for key, value in key_value_pairs: if key in new_keys: new_keys[key] = util.to_list(new_keys[key]) - new_keys[key].append(value) + cast("List[str]", new_keys[key]).append(cast(str, value)) else: - new_keys[key] = value + new_keys[key] = ( + list(value) if isinstance(value, (list, tuple)) else value + ) + new_query: Mapping[str, Union[str, Sequence[str]]] if append: new_query = {} for k in new_keys: if k in existing_query: - new_query[k] = util.to_list( - existing_query[k] - ) + util.to_list(new_keys[k]) + new_query[k] = tuple( + util.to_list(existing_query[k]) + + util.to_list(new_keys[k]) + ) else: new_query[k] = new_keys[k] @@ -362,10 +405,19 @@ class URL(NamedTuple): } ) else: - new_query = self.query.union(new_keys) + new_query = self.query.union( + { + k: tuple(v) if isinstance(v, list) else v + for k, v in new_keys.items() + } + ) return self.set(query=new_query) - def update_query_dict(self, query_parameters, append=False): + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> URL: """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` parameter dictionary updated by the given dictionary. @@ -410,7 +462,7 @@ class URL(NamedTuple): """ # noqa: E501 return self.update_query_pairs(query_parameters.items(), append=append) - def difference_update_query(self, names): + def difference_update_query(self, names: Iterable[str]) -> URL: """ Remove the given names from the :attr:`_engine.URL.query` dictionary, returning the new :class:`_engine.URL`. @@ -459,7 +511,7 @@ class URL(NamedTuple): ) @util.memoized_property - def normalized_query(self): + def normalized_query(self) -> Mapping[str, Sequence[str]]: """Return the :attr:`_engine.URL.query` dictionary with values normalized into sequences. @@ -494,7 +546,7 @@ class URL(NamedTuple): "be removed in a future release. Please use the " ":meth:`_engine.URL.render_as_string` method.", ) - def __to_string__(self, hide_password=True): + def __to_string__(self, hide_password: bool = True) -> str: """Render this :class:`_engine.URL` object as a string. :param hide_password: Defaults to True. The password is not shown @@ -503,7 +555,7 @@ class URL(NamedTuple): """ return self.render_as_string(hide_password=hide_password) - def render_as_string(self, hide_password=True): + def render_as_string(self, hide_password: bool = True) -> str: """Render this :class:`_engine.URL` object as a string. This method is used when the ``__str__()`` or ``__repr__()`` @@ -542,13 +594,13 @@ class URL(NamedTuple): ) return s - def __str__(self): + def __str__(self) -> str: return self.render_as_string(hide_password=False) - def __repr__(self): + def __repr__(self) -> str: return self.render_as_string() - def __copy__(self): + def __copy__(self) -> URL: return self.__class__.create( self.drivername, self.username, @@ -561,13 +613,13 @@ class URL(NamedTuple): self.query, ) - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> URL: return self.__copy__() - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, URL) and self.drivername == other.drivername @@ -579,10 +631,10 @@ class URL(NamedTuple): and self.port == other.port ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self == other - def get_backend_name(self): + def get_backend_name(self) -> str: """Return the backend name. This is the name that corresponds to the database backend in @@ -595,7 +647,7 @@ class URL(NamedTuple): else: return self.drivername.split("+")[0] - def get_driver_name(self): + def get_driver_name(self) -> str: """Return the backend name. This is the name that corresponds to the DBAPI driver in @@ -613,7 +665,9 @@ class URL(NamedTuple): else: return self.drivername.split("+")[1] - def _instantiate_plugins(self, kwargs): + def _instantiate_plugins( + self, kwargs: Mapping[str, Any] + ) -> Tuple[URL, List[Any], Dict[str, Any]]: plugin_names = util.to_list(self.query.get("plugin", ())) plugin_names += kwargs.get("plugins", []) @@ -635,7 +689,7 @@ class URL(NamedTuple): return u, loaded_plugins, kwargs - def _get_entrypoint(self): + def _get_entrypoint(self) -> Type[Dialect]: """Return the "entry point" dialect class. This is normally the dialect itself except in the case when the @@ -657,9 +711,9 @@ class URL(NamedTuple): ): return cls.dialect else: - return cls + return cast("Type[Dialect]", cls) - def get_dialect(self, _is_async=False): + def get_dialect(self, _is_async: bool = False) -> Type[Dialect]: """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding to this URL's driver name. @@ -671,7 +725,9 @@ class URL(NamedTuple): dialect_cls = entrypoint.get_dialect_cls(self) return dialect_cls - def translate_connect_args(self, names=None, **kw): + def translate_connect_args( + self, names: Optional[List[str]] = None, **kw: Any + ) -> Dict[str, Any]: r"""Translate url attributes into a dictionary of connection arguments. Returns attributes of this url (`host`, `database`, `username`, @@ -711,11 +767,12 @@ class URL(NamedTuple): return translated -def make_url(name_or_url): - """Given a string or unicode instance, produce a new URL instance. +def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. The given string is parsed according to the RFC 1738 spec. If an existing URL object is passed, just returns the object. + """ if isinstance(name_or_url, str): @@ -724,7 +781,7 @@ def make_url(name_or_url): return name_or_url -def _parse_rfc1738_args(name): +def _parse_rfc1738_args(name: str) -> URL: pattern = re.compile( r""" (?P[\w\+]+):// @@ -748,13 +805,14 @@ def _parse_rfc1738_args(name): m = pattern.match(name) if m is not None: components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] if components["query"] is not None: query = {} for key, value in parse_qsl(components["query"]): if key in query: query[key] = util.to_list(query[key]) - query[key].append(value) + cast("List[str]", query[key]).append(value) else: query[key] = value else: @@ -775,7 +833,7 @@ def _parse_rfc1738_args(name): if components["port"]: components["port"] = int(components["port"]) - return URL.create(name, **components) + return URL.create(name, **components) # type: ignore else: raise exc.ArgumentError( @@ -783,18 +841,8 @@ def _parse_rfc1738_args(name): ) -def _rfc_1738_quote(text): +def _rfc_1738_quote(text: str) -> str: return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) _rfc_1738_unquote = unquote - - -def _parse_keyvalue_args(name): - m = re.match(r"(\w+)://(.*)", name) - if m is not None: - (name, args) = m.group(1, 2) - opts = dict(parse_qsl(args)) - return URL(name, *opts) - else: - return None diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index f9ee65befe..213485cc92 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -7,18 +7,30 @@ from __future__ import annotations +import typing +from typing import Any +from typing import Callable +from typing import TypeVar + from .. import exc from .. import util +from ..util._has_cy import HAS_CYEXTENSION + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_util import _distill_params_20 as _distill_params_20 + from ._py_util import _distill_raw_params as _distill_raw_params +else: + from sqlalchemy.cyextension.util import ( + _distill_params_20 as _distill_params_20, + ) + from sqlalchemy.cyextension.util import ( + _distill_raw_params as _distill_raw_params, + ) -try: - from sqlalchemy.cyextension.util import _distill_params_20 # noqa - from sqlalchemy.cyextension.util import _distill_raw_params # noqa -except ImportError: - from ._py_util import _distill_params_20 # noqa - from ._py_util import _distill_raw_params # noqa +_C = TypeVar("_C", bound=Callable[[], Any]) -def connection_memoize(key): +def connection_memoize(key: str) -> Callable[[_C], _C]: """Decorator, memoize a function in a connection.info stash. Only applicable to functions which take no arguments other than a @@ -26,7 +38,7 @@ def connection_memoize(key): """ @util.decorator - def decorated(fn, self, connection): + def decorated(fn, self, connection): # type: ignore connection = connection.connect() try: return connection.info[key] @@ -34,7 +46,7 @@ def connection_memoize(key): connection.info[key] = val = fn(self, connection) return val - return decorated + return decorated # type: ignore[return-value] class TransactionalContext: @@ -47,13 +59,13 @@ class TransactionalContext: __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__") - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: raise NotImplementedError() - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: raise NotImplementedError() - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: """indicates the object is in a state that is known to be acceptable for rollback() to be called. @@ -70,11 +82,20 @@ class TransactionalContext: """ raise NotImplementedError() - def _get_subject(self): + def _get_subject(self) -> Any: + raise NotImplementedError() + + def commit(self) -> None: + raise NotImplementedError() + + def rollback(self) -> None: + raise NotImplementedError() + + def close(self) -> None: raise NotImplementedError() @classmethod - def _trans_ctx_check(cls, subject): + def _trans_ctx_check(cls, subject: Any) -> None: trans_context = subject._trans_context_manager if trans_context: if not trans_context._transaction_is_active(): @@ -84,7 +105,7 @@ class TransactionalContext: "before emitting further commands." ) - def __enter__(self): + def __enter__(self) -> TransactionalContext: subject = self._get_subject() # none for outer transaction, may be non-None for nested @@ -96,7 +117,7 @@ class TransactionalContext: subject._trans_context_manager = self return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: subject = getattr(self, "_trans_subject", None) # simplistically we could assume that @@ -119,6 +140,7 @@ class TransactionalContext: self.rollback() finally: if not out_of_band_exit: + assert subject is not None subject._trans_context_manager = self._outer_trans_ctx self._trans_subject = self._outer_trans_ctx = None else: @@ -131,5 +153,6 @@ class TransactionalContext: self.rollback() finally: if not out_of_band_exit: + assert subject is not None subject._trans_context_manager = self._outer_trans_ctx self._trans_subject = self._outer_trans_ctx = None diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py index 0dfb39e1a0..e1c9496813 100644 --- a/lib/sqlalchemy/event/__init__.py +++ b/lib/sqlalchemy/event/__init__.py @@ -15,6 +15,7 @@ from .api import NO_RETVAL as NO_RETVAL from .api import remove as remove from .attr import RefCollection as RefCollection from .base import _Dispatch as _Dispatch +from .base import _DispatchCommon as _DispatchCommon from .base import dispatcher as dispatcher from .base import Events as Events from .legacy import _legacy_signature as _legacy_signature diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 9692894fe8..afae8a59a2 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -605,14 +605,14 @@ class _ListenerCollection(_CompoundListener[_ET]): class _JoinedListener(_CompoundListener[_ET]): __slots__ = "parent_dispatch", "name", "local", "parent_listeners" - parent_dispatch: _Dispatch[_ET] + parent_dispatch: _DispatchCommon[_ET] name: str local: _InstanceLevelDispatch[_ET] parent_listeners: Collection[_ListenerFnType] def __init__( self, - parent_dispatch: _Dispatch[_ET], + parent_dispatch: _DispatchCommon[_ET], name: str, local: _EmptyListener[_ET], ): diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index ef3ff9dab3..4174b1dbea 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -75,6 +75,18 @@ class _UnpickleDispatch: class _DispatchCommon(Generic[_ET]): __slots__ = () + _instance_cls: Optional[Type[_ET]] + + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: + raise NotImplementedError() + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + raise NotImplementedError() + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + raise NotImplementedError() + class _Dispatch(_DispatchCommon[_ET]): """Mirror the event listening definitions of an Events class with @@ -169,7 +181,7 @@ class _Dispatch(_DispatchCommon[_ET]): instance_cls = instance.__class__ return self._for_class(instance_cls) - def _join(self, other: _Dispatch[_ET]) -> _JoinedDispatcher[_ET]: + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: """Create a 'join' of this :class:`._Dispatch` and another. This new dispatcher will dispatch events to both @@ -372,11 +384,13 @@ class _JoinedDispatcher(_DispatchCommon[_ET]): __slots__ = "local", "parent", "_instance_cls" - local: _Dispatch[_ET] - parent: _Dispatch[_ET] + local: _DispatchCommon[_ET] + parent: _DispatchCommon[_ET] _instance_cls: Optional[Type[_ET]] - def __init__(self, local: _Dispatch[_ET], parent: _Dispatch[_ET]): + def __init__( + self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] + ): self.local = local self.parent = parent self._instance_cls = self.local._instance_cls @@ -416,7 +430,7 @@ class dispatcher(Generic[_ET]): ... @overload - def __get__(self, obj: Any, cls: Type[Any]) -> _Dispatch[_ET]: + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: ... def __get__(self, obj: Any, cls: Type[Any]) -> Any: diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 1383e024a1..cc78e0971c 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -27,8 +27,11 @@ from .util import _preloaded from .util import compat if typing.TYPE_CHECKING: + from .engine.interfaces import _AnyExecuteParams + from .engine.interfaces import _CoreAnyExecuteParams + from .engine.interfaces import _CoreMultiExecuteParams + from .engine.interfaces import _DBAPIAnyExecuteParams from .engine.interfaces import Dialect - from .sql._typing import _ExecuteParams from .sql.compiler import Compiled from .sql.elements import ClauseElement @@ -446,7 +449,7 @@ class StatementError(SQLAlchemyError): statement: Optional[str] = None """The string SQL statement being invoked when this exception occurred.""" - params: Optional["_ExecuteParams"] = None + params: Optional[_AnyExecuteParams] = None """The parameter list being used when this exception occurred.""" orig: Optional[BaseException] = None @@ -457,11 +460,13 @@ class StatementError(SQLAlchemyError): ismulti: Optional[bool] = None """multi parameter passed to repr_params(). None is meaningful.""" + connection_invalidated: bool = False + def __init__( self, message: str, statement: Optional[str], - params: Optional["_ExecuteParams"], + params: Optional[_AnyExecuteParams], orig: Optional[BaseException], hide_parameters: bool = False, code: Optional[str] = None, @@ -553,8 +558,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: DontWrapMixin, dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -568,8 +573,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: Exception, dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -583,8 +588,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: BaseException, dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -597,8 +602,8 @@ class DBAPIError(StatementError): @classmethod def instance( cls, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: Union[BaseException, DontWrapMixin], dbapi_base_err: Type[Exception], hide_parameters: bool = False, @@ -684,8 +689,8 @@ class DBAPIError(StatementError): def __init__( self, - statement: str, - params: "_ExecuteParams", + statement: Optional[str], + params: Optional[_AnyExecuteParams], orig: BaseException, hide_parameters: bool = False, connection_invalidated: bool = False, diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 8da45ed0d7..9a23d89d3f 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -255,19 +255,19 @@ class echo_property: @overload def __get__( - self, instance: "Literal[None]", owner: "echo_property" - ) -> "echo_property": + self, instance: Literal[None], owner: Type[Identified] + ) -> echo_property: ... @overload def __get__( - self, instance: Identified, owner: "echo_property" + self, instance: Identified, owner: Type[Identified] ) -> _EchoFlagType: ... def __get__( - self, instance: Optional[Identified], owner: "echo_property" - ) -> Union["echo_property", _EchoFlagType]: + self, instance: Optional[Identified], owner: Type[Identified] + ) -> Union[echo_property, _EchoFlagType]: if instance is None: return self else: diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 2c52a70650..1fc77243a9 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -18,8 +18,8 @@ SQLAlchemy connection pool. """ from . import events -from .base import _AdhocProxiedConnection -from .base import _ConnectionFairy +from .base import _AdhocProxiedConnection as _AdhocProxiedConnection +from .base import _ConnectionFairy as _ConnectionFairy from .base import _ConnectionRecord from .base import _finalize_fairy from .base import ConnectionPoolEntry as ConnectionPoolEntry diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 18d268182d..c1008de5f5 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from ..engine.interfaces import DBAPICursor from ..engine.interfaces import Dialect from ..event import _Dispatch + from ..event import _DispatchCommon from ..event import _ListenerFnType from ..event import dispatcher @@ -132,7 +133,7 @@ class Pool(log.Identified, event.EventTarget): events: Optional[List[Tuple[_ListenerFnType, str]]] = None, dialect: Optional[Union[_ConnDialect, Dialect]] = None, pre_ping: bool = False, - _dispatch: Optional[_Dispatch[Pool]] = None, + _dispatch: Optional[_DispatchCommon[Pool]] = None, ): """ Construct a Pool. @@ -443,78 +444,72 @@ class ManagesConnection: """ - @property - def driver_connection(self) -> Optional[Any]: - """The "driver level" connection object as used by the Python - DBAPI or database driver. + driver_connection: Optional[Any] + """The "driver level" connection object as used by the Python + DBAPI or database driver. - For traditional :pep:`249` DBAPI implementations, this object will - be the same object as that of - :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database - driver, this will be the ultimate "connection" object used by that - driver, such as the ``asyncpg.Connection`` object which will not have - standard pep-249 methods. + For traditional :pep:`249` DBAPI implementations, this object will + be the same object as that of + :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database + driver, this will be the ultimate "connection" object used by that + driver, such as the ``asyncpg.Connection`` object which will not have + standard pep-249 methods. - .. versionadded:: 1.4.24 + .. versionadded:: 1.4.24 - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.dbapi_connection` + :attr:`.ManagesConnection.dbapi_connection` - :ref:`faq_dbapi_connection` + :ref:`faq_dbapi_connection` - """ - raise NotImplementedError() + """ - @util.dynamic_property - def info(self) -> Dict[str, Any]: - """Info dictionary associated with the underlying DBAPI connection - referred to by this :class:`.ManagesConnection` instance, allowing - user-defined data to be associated with the connection. + info: Dict[str, Any] + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ManagesConnection` instance, allowing + user-defined data to be associated with the connection. - The data in this dictionary is persistent for the lifespan - of the DBAPI connection itself, including across pool checkins - and checkouts. When the connection is invalidated - and replaced with a new one, this dictionary is cleared. + The data in this dictionary is persistent for the lifespan + of the DBAPI connection itself, including across pool checkins + and checkouts. When the connection is invalidated + and replaced with a new one, this dictionary is cleared. - For a :class:`.PoolProxiedConnection` instance that's not associated - with a :class:`.ConnectionPoolEntry`, such as if it were detached, the - attribute returns a dictionary that is local to that - :class:`.ConnectionPoolEntry`. Therefore the - :attr:`.ManagesConnection.info` attribute will always provide a Python - dictionary. + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns a dictionary that is local to that + :class:`.ConnectionPoolEntry`. Therefore the + :attr:`.ManagesConnection.info` attribute will always provide a Python + dictionary. - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.record_info` + :attr:`.ManagesConnection.record_info` - """ - raise NotImplementedError() + """ - @util.dynamic_property - def record_info(self) -> Optional[Dict[str, Any]]: - """Persistent info dictionary associated with this - :class:`.ManagesConnection`. + record_info: Optional[Dict[str, Any]] + """Persistent info dictionary associated with this + :class:`.ManagesConnection`. - Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan - of this dictionary is that of the :class:`.ConnectionPoolEntry` - which owns it; therefore this dictionary will persist across - reconnects and connection invalidation for a particular entry - in the connection pool. + Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan + of this dictionary is that of the :class:`.ConnectionPoolEntry` + which owns it; therefore this dictionary will persist across + reconnects and connection invalidation for a particular entry + in the connection pool. - For a :class:`.PoolProxiedConnection` instance that's not associated - with a :class:`.ConnectionPoolEntry`, such as if it were detached, the - attribute returns None. Contrast to the :attr:`.ManagesConnection.info` - dictionary which is never None. + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns None. Contrast to the :attr:`.ManagesConnection.info` + dictionary which is never None. - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.info` + :attr:`.ManagesConnection.info` - """ - raise NotImplementedError() + """ def invalidate( self, e: Optional[BaseException] = None, soft: bool = False @@ -618,7 +613,7 @@ class _ConnectionRecord(ConnectionPoolEntry): dbapi_connection: Optional[DBAPIConnection] @property - def driver_connection(self) -> Optional[Any]: + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa E501 if self.dbapi_connection is None: return None else: @@ -637,11 +632,11 @@ class _ConnectionRecord(ConnectionPoolEntry): _soft_invalidate_time: float = 0 @util.memoized_property - def info(self) -> Dict[str, Any]: + def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125 return {} @util.memoized_property - def record_info(self) -> Optional[Dict[str, Any]]: + def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501 return {} @classmethod @@ -1048,7 +1043,7 @@ class _AdhocProxiedConnection(PoolProxiedConnection): """ - __slots__ = ("dbapi_connection", "_connection_record") + __slots__ = ("dbapi_connection", "_connection_record", "_is_valid") dbapi_connection: DBAPIConnection _connection_record: ConnectionPoolEntry @@ -1060,9 +1055,10 @@ class _AdhocProxiedConnection(PoolProxiedConnection): ): self.dbapi_connection = dbapi_connection self._connection_record = connection_record + self._is_valid = True @property - def driver_connection(self) -> Any: + def driver_connection(self) -> Any: # type: ignore[override] # mypy#4125 return self._connection_record.driver_connection @property @@ -1071,10 +1067,21 @@ class _AdhocProxiedConnection(PoolProxiedConnection): @property def is_valid(self) -> bool: - raise AttributeError("is_valid not implemented by this proxy") + """Implement is_valid state attribute. + + for the adhoc proxied connection it's assumed the connection is valid + as there is no "invalidate" routine. + + """ + return self._is_valid - @util.dynamic_property - def record_info(self) -> Optional[Dict[str, Any]]: + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + self._is_valid = False + + @property + def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501 return self._connection_record.record_info def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: @@ -1140,7 +1147,7 @@ class _ConnectionFairy(PoolProxiedConnection): _connection_record: Optional[_ConnectionRecord] @property - def driver_connection(self) -> Optional[Any]: + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa E501 if self._connection_record is None: return None return self._connection_record.driver_connection @@ -1305,17 +1312,17 @@ class _ConnectionFairy(PoolProxiedConnection): @property def is_detached(self) -> bool: - return self._connection_record is not None + return self._connection_record is None @util.memoized_property - def info(self) -> Dict[str, Any]: + def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125 if self._connection_record is None: return {} else: return self._connection_record.info - @util.dynamic_property - def record_info(self) -> Optional[Dict[str, Any]]: + @property + def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa E501 if self._connection_record is None: return None else: diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 7d8b9ee5c4..69e4645fa6 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -1,20 +1,11 @@ from __future__ import annotations -from typing import Any -from typing import Mapping -from typing import Sequence from typing import Type from typing import Union from . import roles from ..inspection import Inspectable -from ..util import immutabledict -_SingleExecuteParams = Mapping[str, Any] -_MultiExecuteParams = Sequence[_SingleExecuteParams] -_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams] -_ExecuteOptions = Mapping[str, Any] -_ImmutableExecuteOptions = immutabledict[str, Any] _ColumnsClauseElement = Union[ roles.ColumnsClauseRole, Type, Inspectable[roles.HasClauseElement] ] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 3936ed9c63..a94590da1c 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -19,11 +19,12 @@ from itertools import zip_longest import operator import re import typing +from typing import Optional +from typing import Sequence from typing import TypeVar from . import roles from . import visitors -from ._typing import _ImmutableExecuteOptions from .cache_key import HasCacheKey # noqa from .cache_key import MemoizedHasCacheKey # noqa from .traversals import HasCopyInternals # noqa @@ -32,7 +33,7 @@ from .visitors import ExtendedInternalTraversal from .visitors import InternalTraversal from .. import exc from .. import util -from ..util import HasMemoized +from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing from ..util._has_cy import HAS_CYEXTENSION @@ -42,6 +43,16 @@ if typing.TYPE_CHECKING or not HAS_CYEXTENSION: else: from sqlalchemy.cyextension.util import prefix_anon_map # noqa +if typing.TYPE_CHECKING: + from ..engine import Connection + from ..engine import Result + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ExecuteOptionsParameter + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CacheStats + + coercions = None elements = None type_api = None @@ -856,6 +867,32 @@ class Executable(roles.StatementRole, Generative): is_delete = False is_dml = False + if typing.TYPE_CHECKING: + + def _compile_w_cache( + self, + dialect: Dialect, + compiled_cache: Optional[_CompiledCacheType] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + schema_translate_map: Optional[_SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[Compiled, _SingleExecuteParams, CacheStats]: + ... + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + _force: bool = False, + ) -> Result: + ... + + @property + def _all_selected_columns(self): + raise NotImplementedError() + @property def _effective_plugin_target(self): return self.__visit_name__ diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 49f1899d5a..ff659b77de 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -7,10 +7,12 @@ from __future__ import annotations -from collections import namedtuple import enum from itertools import zip_longest +import typing +from typing import Any from typing import Callable +from typing import NamedTuple from typing import Union from .visitors import anon_map @@ -22,6 +24,10 @@ from ..util import HasMemoized from ..util.typing import Literal +if typing.TYPE_CHECKING: + from .elements import BindParameter + + class CacheConst(enum.Enum): NO_CACHE = 0 @@ -345,7 +351,7 @@ class MemoizedHasCacheKey(HasCacheKey, HasMemoized): return HasCacheKey._generate_cache_key(self) -class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): +class CacheKey(NamedTuple): """The key used to identify a SQL statement construct in the SQL compilation cache. @@ -355,6 +361,9 @@ class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])): """ + key: Tuple[Any, ...] + bindparams: Sequence[BindParameter] + def __hash__(self): """CacheKey itself is not hashable - hash the .key portion""" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d0f114d6c9..712d314624 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -27,6 +27,7 @@ from __future__ import annotations import collections import collections.abc as collections_abc import contextlib +from enum import IntEnum import itertools import operator import re @@ -35,9 +36,13 @@ import typing from typing import Any from typing import Dict from typing import List +from typing import Mapping from typing import MutableMapping +from typing import NamedTuple from typing import Optional +from typing import Sequence from typing import Tuple +from typing import Union from . import base from . import coercions @@ -51,12 +56,17 @@ from . import sqltypes from .base import NO_ARG from .base import prefix_anon_map from .elements import quoted_name +from .schema import Column +from .type_api import TypeEngine from .. import exc from .. import util +from ..util.typing import Literal if typing.TYPE_CHECKING: from .selectable import CTE from .selectable import FromClause + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.result import _ProcessorType _FromHintsType = Dict["FromClause", str] @@ -271,42 +281,71 @@ COMPOUND_KEYWORDS = { } -RM_RENDERED_NAME = 0 -RM_NAME = 1 -RM_OBJECTS = 2 -RM_TYPE = 3 +class ResultColumnsEntry(NamedTuple): + """Tracks a column expression that is expected to be represented + in the result rows for this statement. + This normally refers to the columns clause of a SELECT statement + but may also refer to a RETURNING clause, as well as for dialect-specific + emulations. -ExpandedState = collections.namedtuple( - "ExpandedState", - [ - "statement", - "additional_parameters", - "processors", - "positiontup", - "parameter_expansion", - ], -) + """ + keyname: str + """string name that's expected in cursor.description""" -NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + name: str + """column name, may be labeled""" -COLLECT_CARTESIAN_PRODUCTS = util.symbol( - "COLLECT_CARTESIAN_PRODUCTS", - "Collect data on FROMs and cartesian products and gather " - "into 'self.from_linter'", - canonical=1, -) + objects: List[Any] + """list of objects that should be able to locate this column + in a RowMapping. This is typically string names and aliases + as well as Column objects. -WARN_LINTING = util.symbol( - "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 -) + """ + + type: TypeEngine[Any] + """Datatype to be associated with this column. This is where + the "result processing" logic directly links the compiled statement + to the rows that come back from the cursor. + + """ + + +# integer indexes into ResultColumnsEntry used by cursor.py. +# some profiling showed integer access faster than named tuple +RM_RENDERED_NAME: Literal[0] = 0 +RM_NAME: Literal[1] = 1 +RM_OBJECTS: Literal[2] = 2 +RM_TYPE: Literal[3] = 3 + + +class ExpandedState(NamedTuple): + statement: str + additional_parameters: _CoreSingleExecuteParams + processors: Mapping[str, _ProcessorType] + positiontup: Optional[Sequence[str]] + parameter_expansion: Mapping[str, List[str]] + + +class Linting(IntEnum): + NO_LINTING = 0 + "Disable all linting." + + COLLECT_CARTESIAN_PRODUCTS = 1 + """Collect data on FROMs and cartesian products and gather into + 'self.from_linter'""" + + WARN_LINTING = 2 + "Emit warnings for linters that find problems" -FROM_LINTING = util.symbol( - "FROM_LINTING", - "Warn for cartesian products; " - "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", - canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, + FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING + """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS + and WARN_LINTING""" + + +NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple( + Linting ) @@ -389,7 +428,7 @@ class Compiled: _cached_metadata = None - _result_columns = None + _result_columns: Optional[List[ResultColumnsEntry]] = None schema_translate_map = None @@ -418,7 +457,8 @@ class Compiled: """ cache_key = None - _gen_time = None + + _gen_time: float def __init__( self, @@ -573,15 +613,43 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP + _result_columns: List[ResultColumnsEntry] + compound_keywords = COMPOUND_KEYWORDS - isdelete = isinsert = isupdate = False + isdelete: bool = False + isinsert: bool = False + isupdate: bool = False """class-level defaults which can be set at the instance level to define if this Compiled instance represents INSERT/UPDATE/DELETE """ - isplaintext = False + postfetch: Optional[List[Column[Any]]] + """list of columns that can be post-fetched after INSERT or UPDATE to + receive server-updated values""" + + insert_prefetch: Optional[List[Column[Any]]] + """list of columns for which default values should be evaluated before + an INSERT takes place""" + + update_prefetch: Optional[List[Column[Any]]] + """list of columns for which onupdate default values should be evaluated + before an UPDATE takes place""" + + returning: Optional[List[Column[Any]]] + """list of columns that will be delivered to cursor.description or + dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE + + """ + + isplaintext: bool = False + + result_columns: List[ResultColumnsEntry] + """relates label names in the final SQL to a tuple of local + column/label name, ColumnElement object (if any) and + TypeEngine. CursorResult uses this for type processing and + column targeting""" returning = None """holds the "returning" collection of columns if @@ -589,18 +657,18 @@ class SQLCompiler(Compiled): either implicitly or explicitly """ - returning_precedes_values = False + returning_precedes_values: bool = False """set to True classwide to generate RETURNING clauses before the VALUES or WHERE clause (i.e. MSSQL) """ - render_table_with_column_in_update_from = False + render_table_with_column_in_update_from: bool = False """set to True classwide to indicate the SET clause in a multi-table UPDATE statement should qualify columns with the table name (i.e. MySQL only) """ - ansi_bind_rules = False + ansi_bind_rules: bool = False """SQL 92 doesn't allow bind parameters to be used in the columns clause of a SELECT, nor does it allow ambiguous expressions like "? = ?". A compiler @@ -608,33 +676,33 @@ class SQLCompiler(Compiled): driver/DB enforces this """ - _textual_ordered_columns = False + _textual_ordered_columns: bool = False """tell the result object that the column names as rendered are important, but they are also "ordered" vs. what is in the compiled object here. """ - _ordered_columns = True + _ordered_columns: bool = True """ if False, means we can't be sure the list of entries in _result_columns is actually the rendered order. Usually True unless using an unordered TextualSelect. """ - _loose_column_name_matching = False + _loose_column_name_matching: bool = False """tell the result object that the SQL statement is textual, wants to match up to Column objects, and may be using the ._tq_label in the SELECT rather than the base name. """ - _numeric_binds = False + _numeric_binds: bool = False """ True if paramstyle is "numeric". This paramstyle is trickier than all the others. """ - _render_postcompile = False + _render_postcompile: bool = False """ whether to render out POSTCOMPILE params during the compile phase. @@ -684,7 +752,7 @@ class SQLCompiler(Compiled): """ - positiontup = None + positiontup: Optional[Sequence[str]] = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -699,7 +767,7 @@ class SQLCompiler(Compiled): """ - inline = False + inline: bool = False def __init__( self, @@ -760,10 +828,6 @@ class SQLCompiler(Compiled): # stack which keeps track of nested SELECT statements self.stack = [] - # relates label names in the final SQL to a tuple of local - # column/label name, ColumnElement object (if any) and - # TypeEngine. CursorResult uses this for type processing and - # column targeting self._result_columns = [] # true if the paramstyle is positional @@ -910,7 +974,9 @@ class SQLCompiler(Compiled): ) @util.memoized_property - def _bind_processors(self): + def _bind_processors( + self, + ) -> MutableMapping[str, Union[_ProcessorType, Sequence[_ProcessorType]]]: return dict( (key, value) for key, value in ( @@ -1098,8 +1164,10 @@ class SQLCompiler(Compiled): return self.construct_params(_check=False) def _process_parameters_for_postcompile( - self, parameters=None, _populate_self=False - ): + self, + parameters: Optional[_CoreSingleExecuteParams] = None, + _populate_self: bool = False, + ) -> ExpandedState: """handle special post compile parameters. These include: @@ -3070,7 +3138,13 @@ class SQLCompiler(Compiled): def get_render_as_alias_suffix(self, alias_name_text): return " AS " + alias_name_text - def _add_to_result_map(self, keyname, name, objects, type_): + def _add_to_result_map( + self, + keyname: str, + name: str, + objects: List[Any], + type_: TypeEngine[Any], + ) -> None: if keyname is None or keyname == "*": self._ordered_columns = False self._textual_ordered_columns = True @@ -3080,7 +3154,9 @@ class SQLCompiler(Compiled): "from a tuple() object. If this is an ORM query, " "consider using the Bundle object." ) - self._result_columns.append((keyname, name, objects, type_)) + self._result_columns.append( + ResultColumnsEntry(keyname, name, objects, type_) + ) def _label_returning_column(self, stmt, column, column_clause_args=None): """Render a column with necessary labels inside of a RETURNING clause. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 0c532a135a..ac5dc46db1 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -61,6 +61,11 @@ if typing.TYPE_CHECKING: from .selectable import Select from .sqltypes import Boolean # noqa from .type_api import TypeEngine + from ..engine import Compiled + from ..engine import Connection + from ..engine import Dialect + from ..engine import Engine + _NUMERIC = Union[complex, "Decimal"] @@ -145,7 +150,12 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.default") @util.preload_module("sqlalchemy.engine.url") - def compile(self, bind=None, dialect=None, **kw): + def compile( + self, + bind: Optional[Union[Engine, Connection]] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> Compiled: """Compile this SQL expression. The return value is a :class:`~.Compiled` object. diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 5286917959..fdae4d7b04 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -174,7 +174,13 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): _use_schema_map = True -class Table(DialectKWArgs, SchemaItem, TableClause): +class HasSchemaAttr(SchemaItem): + """schema item that includes a top-level schema name""" + + schema: Optional[str] + + +class Table(DialectKWArgs, HasSchemaAttr, TableClause): r"""Represent a table in a database. e.g.:: @@ -2850,7 +2856,7 @@ class IdentityOptions: self.order = order -class Sequence(IdentityOptions, DefaultGenerator): +class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): """Represents a named database sequence. The :class:`.Sequence` object represents the name and configurational @@ -4330,7 +4336,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem): DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"}) -class MetaData(SchemaItem): +class MetaData(HasSchemaAttr): """A collection of :class:`_schema.Table` objects and their associated schema constructs. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e3e358cdb9..e0248adf0d 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,9 +21,6 @@ from . import coercions from . import operators from . import roles from . import visitors -from ._typing import _ExecuteParams -from ._typing import _MultiExecuteParams -from ._typing import _SingleExecuteParams from .annotation import _deep_annotate # noqa from .annotation import _deep_deannotate # noqa from .annotation import _shallow_annotate # noqa @@ -54,6 +51,10 @@ from .. import exc from .. import util if typing.TYPE_CHECKING: + from ..engine.interfaces import _AnyExecuteParams + from ..engine.interfaces import _AnyMultiExecuteParams + from ..engine.interfaces import _AnySingleExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.row import Row @@ -550,12 +551,12 @@ class _repr_params(_repr_base): def __init__( self, - params: _ExecuteParams, + params: Optional[_AnyExecuteParams], batches: int, max_chars: int = 300, ismulti: Optional[bool] = None, ): - self.params: _ExecuteParams = params + self.params = params self.ismulti = ismulti self.batches = batches self.max_chars = max_chars @@ -575,7 +576,10 @@ class _repr_params(_repr_base): return self.trunc(self.params) if self.ismulti: - multi_params = cast(_MultiExecuteParams, self.params) + multi_params = cast( + "_AnyMultiExecuteParams", + self.params, + ) if len(self.params) > self.batches: msg = ( @@ -595,10 +599,18 @@ class _repr_params(_repr_base): return self._repr_multi(multi_params, typ) else: return self._repr_params( - cast(_SingleExecuteParams, self.params), typ + cast( + "_AnySingleExecuteParams", + self.params, + ), + typ, ) - def _repr_multi(self, multi_params: _MultiExecuteParams, typ) -> str: + def _repr_multi( + self, + multi_params: _AnyMultiExecuteParams, + typ, + ) -> str: if multi_params: if isinstance(multi_params[0], list): elem_type = self._LIST @@ -622,13 +634,19 @@ class _repr_params(_repr_base): else: return "(%s)" % elements - def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str: + def _repr_params( + self, + params: Optional[_AnySingleExecuteParams], + typ: int, + ) -> str: trunc = self.trunc if typ is self._DICT: return "{%s}" % ( ", ".join( "%r: %s" % (key, trunc(value)) - for key, value in params.items() + for key, value in cast( + "_CoreSingleExecuteParams", params + ).items() ) ) elif typ is self._TUPLE: diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 523426d092..111ecd32ef 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -28,6 +28,8 @@ from __future__ import annotations from collections import deque import itertools import operator +import typing +from typing import Any from typing import List from typing import Tuple @@ -35,12 +37,13 @@ from .. import exc from .. import util from ..util import langhelpers from ..util import symbol +from ..util._has_cy import HAS_CYEXTENSION from ..util.langhelpers import _symbol -try: - from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa -except ImportError: +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import cache_anon_map as anon_map # noqa +else: + from sqlalchemy.cyextension.util import cache_anon_map as anon_map # noqa __all__ = [ "iterate", @@ -554,7 +557,7 @@ class ExternalTraversal: __traverse_options__ = {} - def traverse_single(self, obj, **kw): + def traverse_single(self, obj: Visitable, **kw: Any) -> Any: for v in self.visitor_iterator: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index a414205045..e5cf9b92e6 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -94,7 +94,6 @@ from .langhelpers import decode_slice as decode_slice from .langhelpers import decorator as decorator from .langhelpers import dictlike_iteritems as dictlike_iteritems from .langhelpers import duck_type_collection as duck_type_collection -from .langhelpers import dynamic_property as dynamic_property from .langhelpers import ellipses_string as ellipses_string from .langhelpers import EnsureKWArg as EnsureKWArg from .langhelpers import format_argspec_init as format_argspec_init diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 84735316d2..e0b53b4450 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -513,7 +513,12 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): threshold: float size_alert: Optional[Callable[["LRUCache[_KT, _VT]"], None]] - def __init__(self, capacity=100, threshold=0.5, size_alert=None): + def __init__( + self, + capacity: int = 100, + threshold: float = 0.5, + size_alert: Optional[Callable[..., None]] = None, + ): self.capacity = capacity self.threshold = threshold self.size_alert = size_alert diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index ee54180ac4..771e974e93 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -15,9 +15,11 @@ from typing import Dict from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import NoReturn from typing import Optional from typing import Set +from typing import Tuple from typing import TypeVar from typing import Union @@ -65,13 +67,15 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.__init__(new, *args) return new - def __init__(self, *args): + def __init__(self, *args: Union[Mapping[_KT, _VT], Tuple[_KT, _VT]]): pass def __reduce__(self): return immutabledict, (dict(self),) - def union(self, __d=None): + def union( + self, __d: Optional[Mapping[_KT, _VT]] = None + ) -> immutabledict[_KT, _VT]: if not __d: return self @@ -80,7 +84,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.update(new, __d) return new - def _union_w_kw(self, __d=None, **kw): + def _union_w_kw( + self, __d: Optional[Mapping[_KT, _VT]] = None, **kw: _VT + ) -> immutabledict[_KT, _VT]: # not sure if C version works correctly w/ this yet if not __d and not kw: return self @@ -92,7 +98,9 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): dict.update(new, kw) # type: ignore return new - def merge_with(self, *dicts): + def merge_with( + self, *dicts: Optional[Mapping[_KT, _VT]] + ) -> immutabledict[_KT, _VT]: new = None for d in dicts: if d: @@ -105,7 +113,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): return new - def __repr__(self): + def __repr__(self) -> str: return "immutabledict(%s)" % dict.__repr__(self) diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py index 7e1d3213ab..a8e58a8bfb 100644 --- a/lib/sqlalchemy/util/deprecations.py +++ b/lib/sqlalchemy/util/deprecations.py @@ -28,7 +28,6 @@ from . import compat from .langhelpers import _hash_limit_string from .langhelpers import _warnings_warn from .langhelpers import decorator -from .langhelpers import dynamic_property from .langhelpers import inject_docstring_text from .langhelpers import inject_param_text from .. import exc @@ -103,7 +102,7 @@ def deprecated_property( add_deprecation_to_docstring: bool = True, warning: Optional[Type[exc.SADeprecationWarning]] = None, enable_warnings: bool = True, -) -> Callable[[Callable[..., _T]], dynamic_property[_T]]: +) -> Callable[[Callable[..., Any]], property]: """the @deprecated decorator with a @property. E.g.:: @@ -131,8 +130,8 @@ def deprecated_property( """ - def decorate(fn: Callable[..., _T]) -> dynamic_property[_T]: - return dynamic_property( + def decorate(fn: Callable[..., Any]) -> property: + return property( deprecated( version, message=message, diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 43f9d5c73f..1e79fd5474 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -317,15 +317,17 @@ _P = compat_typing.ParamSpec("_P") class PluginLoader: - def __init__(self, group, auto_fn=None): + def __init__( + self, group: str, auto_fn: Optional[Callable[..., Any]] = None + ): self.group = group - self.impls = {} + self.impls: Dict[str, Any] = {} self.auto_fn = auto_fn def clear(self): self.impls.clear() - def load(self, name): + def load(self, name: str) -> Any: if name in self.impls: return self.impls[name]() @@ -344,7 +346,7 @@ class PluginLoader: "Can't load plugin: %s:%s" % (self.group, name) ) - def register(self, name, modulepath, objname): + def register(self, name: str, modulepath: str, objname: str) -> None: def load(): mod = __import__(modulepath) for token in modulepath.split(".")[1:]: @@ -444,7 +446,7 @@ def get_cls_kwargs( return _set -def get_func_kwargs(func): +def get_func_kwargs(func: Callable[..., Any]) -> List[str]: """Return the set of legal kwargs for the given `func`. Uses getargspec so is safe to call for methods, functions, @@ -1125,22 +1127,13 @@ def as_interface(obj, cls=None, methods=None, required=None): ) -Selfdynamic_property = TypeVar( - "Selfdynamic_property", bound="dynamic_property[Any]" -) - Selfmemoized_property = TypeVar( "Selfmemoized_property", bound="memoized_property[Any]" ) -class dynamic_property(Generic[_T]): - """A read-only @property that is evaluated each time. - - This is mostly the same as @property except we can type it - alongside memoized_property - - """ +class memoized_property(Generic[_T]): + """A read-only @property that is only evaluated once.""" fget: Callable[..., _T] __doc__: Optional[str] @@ -1151,27 +1144,6 @@ class dynamic_property(Generic[_T]): self.__doc__ = doc or fget.__doc__ self.__name__ = fget.__name__ - @overload - def __get__( - self: Selfdynamic_property, obj: None, cls: Any - ) -> Selfdynamic_property: - ... - - @overload - def __get__(self, obj: Any, cls: Any) -> _T: - ... - - def __get__( - self: Selfdynamic_property, obj: Any, cls: Any - ) -> Union[Selfdynamic_property, _T]: - if obj is None: - return self - return self.fget(obj) # type: ignore[no-any-return] - - -class memoized_property(dynamic_property[_T]): - """A read-only @property that is only evaluated once.""" - @overload def __get__( self: Selfmemoized_property, obj: None, cls: Any @@ -1231,24 +1203,27 @@ def memoized_instancemethod(fn): class HasMemoized: - """A class that maintains the names of memoized elements in a + """A mixin class that maintains the names of memoized elements in a collection for easy cache clearing, generative, etc. """ - __slots__ = () + if not typing.TYPE_CHECKING: + # support classes that want to have __slots__ with an explicit + # slot for __dict__. not sure if that requires base __slots__ here. + __slots__ = () _memoized_keys: FrozenSet[str] = frozenset() - def _reset_memoizations(self): + def _reset_memoizations(self) -> None: for elem in self._memoized_keys: self.__dict__.pop(elem, None) - def _assert_no_memoizations(self): + def _assert_no_memoizations(self) -> None: for elem in self._memoized_keys: assert elem not in self.__dict__ - def _set_memoized_attribute(self, key, value): + def _set_memoized_attribute(self, key: str, value: Any) -> None: self.__dict__[key] = value self._memoized_keys |= {key} @@ -1342,7 +1317,7 @@ class MemoizedSlots: # from paste.deploy.converters -def asbool(obj): +def asbool(obj: Any) -> bool: if isinstance(obj, str): obj = obj.strip().lower() if obj in ["true", "yes", "on", "y", "t", "1"]: @@ -1354,13 +1329,13 @@ def asbool(obj): return bool(obj) -def bool_or_str(*text): +def bool_or_str(*text: str) -> Callable[[str], Union[str, bool]]: """Return a callable that will evaluate a string as boolean, or one of a set of "alternate" string values. """ - def bool_or_value(obj): + def bool_or_value(obj: str) -> Union[str, bool]: if obj in text: return obj else: @@ -1369,7 +1344,7 @@ def bool_or_str(*text): return bool_or_value -def asint(value): +def asint(value: Any) -> Optional[int]: """Coerce to integer.""" if value is None: @@ -1377,7 +1352,13 @@ def asint(value): return int(value) -def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None): +def coerce_kw_type( + kw: Dict[str, Any], + key: str, + type_: Type[Any], + flexi_bool: bool = True, + dest: Optional[Dict[str, Any]] = None, +) -> None: r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if necessary. If 'flexi_bool' is True, the string '0' is considered false when coercing to boolean. @@ -1397,7 +1378,7 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None): dest[key] = type_(kw[key]) -def constructor_key(obj, cls): +def constructor_key(obj: Any, cls: Type[Any]) -> Tuple[Any, ...]: """Produce a tuple structure that is cacheable using the __dict__ of obj to retrieve values @@ -1408,7 +1389,7 @@ def constructor_key(obj, cls): ) -def constructor_copy(obj, cls, *args, **kw): +def constructor_copy(obj: _T, cls: Type[_T], *args: Any, **kw: Any) -> _T: """Instantiate cls using the __dict__ of obj as constructor arguments. Uses inspect to match the named arguments of ``cls``. @@ -1422,7 +1403,7 @@ def constructor_copy(obj, cls, *args, **kw): return cls(*args, **kw) -def counter(): +def counter() -> Callable[[], int]: """Return a threadsafe counter function.""" lock = threading.Lock() @@ -1436,47 +1417,51 @@ def counter(): return _next -def duck_type_collection(specimen, default=None): +def duck_type_collection( + specimen: Union[object, Type[Any]], default: Optional[Type[Any]] = None +) -> Type[Any]: """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__ property is present, return that preferentially. """ + if typing.TYPE_CHECKING: + return object + else: + if hasattr(specimen, "__emulates__"): + # canonicalize set vs sets.Set to a standard: the builtin set + if specimen.__emulates__ is not None and issubclass( + specimen.__emulates__, set + ): + return set + else: + return specimen.__emulates__ - if hasattr(specimen, "__emulates__"): - # canonicalize set vs sets.Set to a standard: the builtin set - if specimen.__emulates__ is not None and issubclass( - specimen.__emulates__, set - ): + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): + return list + elif isa(specimen, set): return set + elif isa(specimen, dict): + return dict + + if hasattr(specimen, "append"): + return list + elif hasattr(specimen, "add"): + return set + elif hasattr(specimen, "set"): + return dict else: - return specimen.__emulates__ - - isa = isinstance(specimen, type) and issubclass or isinstance - if isa(specimen, list): - return list - elif isa(specimen, set): - return set - elif isa(specimen, dict): - return dict - - if hasattr(specimen, "append"): - return list - elif hasattr(specimen, "add"): - return set - elif hasattr(specimen, "set"): - return dict - else: - return default + return default -def assert_arg_type(arg, argtype, name): +def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any: if isinstance(arg, argtype): return arg else: if isinstance(argtype, tuple): raise exc.ArgumentError( "Argument '%s' is expected to be one of type %s, got '%s'" - % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) + % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) # type: ignore # noqa E501 ) else: raise exc.ArgumentError( diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index ddda420db1..ad9c8e5314 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -13,7 +13,7 @@ from typing import Type from typing import TypeVar from typing import Union -from typing_extensions import NotRequired # noqa +from typing_extensions import NotRequired as NotRequired # noqa from . import compat diff --git a/pyproject.toml b/pyproject.toml index e79c7292df..b2754b193d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,14 @@ markers = [ [tool.pyright] include = [ + "lib/sqlalchemy/engine/base.py", + "lib/sqlalchemy/engine/events.py", + "lib/sqlalchemy/engine/interfaces.py", + "lib/sqlalchemy/engine/_py_row.py", + "lib/sqlalchemy/engine/result.py", + "lib/sqlalchemy/engine/row.py", + "lib/sqlalchemy/engine/util.py", + "lib/sqlalchemy/engine/url.py", "lib/sqlalchemy/pool/", "lib/sqlalchemy/event/", "lib/sqlalchemy/events.py", @@ -79,9 +87,19 @@ strict = true # the whole library 100% strictly typed, so we have to tune this based on # the type of module or package we are dealing with +[[tool.mypy.overrides]] +# ad-hoc ignores +module = [ + "sqlalchemy.engine.reflection", # interim, should be strict +] + +ignore_errors = true + # strict checking [[tool.mypy.overrides]] module = [ + "sqlalchemy.connectors.*", + "sqlalchemy.engine.*", "sqlalchemy.pool.*", "sqlalchemy.event.*", "sqlalchemy.events", @@ -95,11 +113,16 @@ strict = true # partial checking, internals can be untyped [[tool.mypy.overrides]] -module="sqlalchemy.util.*" + +module = [ + "sqlalchemy.util.*", + "sqlalchemy.engine.cursor", + "sqlalchemy.engine.default", +] + + ignore_errors = false -# util is for internal use so we can get by without everything -# being typed allow_untyped_defs = true check_untyped_defs = false allow_untyped_calls = true diff --git a/test/base/test_result.py b/test/base/test_result.py index 8818ccb145..7a696d352a 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -1,7 +1,6 @@ from sqlalchemy import exc from sqlalchemy import testing from sqlalchemy.engine import result -from sqlalchemy.engine.row import Row from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -728,27 +727,6 @@ class ResultTest(fixtures.TestBase): # still slices eq_(m1.fetchone(), {"b": 1, "c": 2}) - def test_alt_row_fetch(self): - class AppleRow(Row): - def apple(self): - return "apple" - - result = self._fixture(alt_row=AppleRow) - - row = result.all()[0] - eq_(row.apple(), "apple") - - def test_alt_row_transform(self): - class AppleRow(Row): - def apple(self): - return "apple" - - result = self._fixture(alt_row=AppleRow) - - row = result.columns("c", "a").all()[2] - eq_(row.apple(), "apple") - eq_(row, (2, 1)) - def test_scalar_none_iterate(self): result = self._fixture( data=[ diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index b40981a99c..d54a37cebc 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -34,19 +34,19 @@ class ParseConnectTest(fixtures.TestBase): dialect = pyodbc.dialect() u = url.make_url("mssql+pyodbc://mydsn") connection = dialect.create_connect_args(u) - eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection) + eq_((("dsn=mydsn;Trusted_Connection=Yes",), {}), connection) def test_pyodbc_connect_old_style_dsn_trusted(self): dialect = pyodbc.dialect() u = url.make_url("mssql+pyodbc:///?dsn=mydsn") connection = dialect.create_connect_args(u) - eq_([["dsn=mydsn;Trusted_Connection=Yes"], {}], connection) + eq_((("dsn=mydsn;Trusted_Connection=Yes",), {}), connection) def test_pyodbc_connect_dsn_non_trusted(self): dialect = pyodbc.dialect() u = url.make_url("mssql+pyodbc://username:password@mydsn") connection = dialect.create_connect_args(u) - eq_([["dsn=mydsn;UID=username;PWD=password"], {}], connection) + eq_((("dsn=mydsn;UID=username;PWD=password",), {}), connection) def test_pyodbc_connect_dsn_extra(self): dialect = pyodbc.dialect() @@ -66,13 +66,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -99,13 +99,13 @@ class ParseConnectTest(fixtures.TestBase): ) eq_( - [ - [ + ( + ( "Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -117,13 +117,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec,12345;Database=datab" - "ase;UID=username;PWD=password" - ], + "ase;UID=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -135,13 +135,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password;port=12345" - ], + "D=username;PWD=password;port=12345", + ), {}, - ], + ), connection, ) @@ -193,13 +193,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -211,7 +211,7 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [["dsn=mydsn;Database=database;UID=username;PWD=password"], {}], + (("dsn=mydsn;Database=database;UID=username;PWD=password",), {}), connection, ) @@ -225,13 +225,13 @@ class ParseConnectTest(fixtures.TestBase): ) connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={SQL Server};Server=hostspec;Database=database;UI" - "D=username;PWD=password" - ], + "D=username;PWD=password", + ), {}, - ], + ), connection, ) @@ -248,14 +248,14 @@ class ParseConnectTest(fixtures.TestBase): dialect = pyodbc.dialect() connection = dialect.create_connect_args(u) eq_( - [ - [ + ( + ( "DRIVER={foob};Server=somehost%3BPORT%3D50001;" "Database=somedb%3BPORT%3D50001;UID={someuser;PORT=50001};" - "PWD={some{strange}}pw;PORT=50001}" - ], + "PWD={some{strange}}pw;PORT=50001}", + ), {}, - ], + ), connection, ) @@ -265,7 +265,7 @@ class ParseConnectTest(fixtures.TestBase): u = url.make_url("mssql+pymssql://scott:tiger@somehost/test") connection = dialect.create_connect_args(u) eq_( - [ + ( [], { "host": "somehost", @@ -273,14 +273,14 @@ class ParseConnectTest(fixtures.TestBase): "user": "scott", "database": "test", }, - ], + ), connection, ) u = url.make_url("mssql+pymssql://scott:tiger@somehost:5000/test") connection = dialect.create_connect_args(u) eq_( - [ + ( [], { "host": "somehost:5000", @@ -288,7 +288,7 @@ class ParseConnectTest(fixtures.TestBase): "user": "scott", "database": "test", }, - ], + ), connection, ) @@ -584,7 +584,9 @@ class VersionDetectionTest(fixtures.TestBase): ) ) ), - connection=Mock(getinfo=Mock(return_value=vers)), + connection=Mock( + dbapi_connection=Mock(getinfo=Mock(return_value=vers)), + ), ) eq_(dialect._get_server_version_info(conn), expected) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 613fc80a5a..5a92ae6fe0 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -368,11 +368,11 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, {"x": "x3", "y": "y3"}, - ), + ], **expected_kwargs, ) ], @@ -417,11 +417,11 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, {"x": "x3", "y": "y3"}, - ), + ], **expected_kwargs, ) ], @@ -470,11 +470,11 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"x": "x1", "y": "y1"}, {"x": "x2", "y": "y2"}, {"x": "x3", "y": "y3"}, - ), + ], **expected_kwargs, ) ], @@ -524,10 +524,10 @@ class ExecuteManyMode: mock.call( mock.ANY, stmt, - ( + [ {"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}, - ), + ], **expected_kwargs, ) ], @@ -714,11 +714,11 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): mock.call( mock.ANY, "INSERT INTO data (id, x, y, z) VALUES %s", - ( + [ {"id": 1, "y": "y1", "z": 1}, {"id": 2, "y": "y2", "z": 2}, {"id": 3, "y": "y3", "z": 3}, - ), + ], template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)", fetch=False, page_size=connection.dialect.executemany_values_page_size, diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 8dc1f0f484..e1c610701c 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -87,6 +87,93 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): class CreateEngineTest(fixtures.TestBase): + @testing.requires.sqlite + def test_dbapi_clsmethod_renamed(self): + """The dbapi() class method is renamed to import_dbapi(), + so that the .dbapi attribute can be exclusively an instance + attribute. + + """ + + from sqlalchemy.dialects.sqlite import pysqlite + from sqlalchemy.dialects import registry + + canary = mock.Mock() + + class MyDialect(pysqlite.SQLiteDialect_pysqlite): + @classmethod + def dbapi(cls): + canary() + return __import__("sqlite3") + + tokens = __name__.split(".") + + global dialect + dialect = MyDialect + + registry.register( + "mockdialect1.sqlite", ".".join(tokens[0:-1]), tokens[-1] + ) + + with expect_deprecated( + r"The dbapi\(\) classmethod on dialect classes has " + r"been renamed to import_dbapi\(\). Implement an " + r"import_dbapi\(\) classmethod directly on class " + r".*MyDialect.* to remove this warning; the old " + r".dbapi\(\) classmethod may be maintained for backwards " + r"compatibility." + ): + e = create_engine("mockdialect1+sqlite://") + + eq_(canary.mock_calls, [mock.call()]) + sqlite3 = __import__("sqlite3") + is_(e.dialect.dbapi, sqlite3) + + @testing.requires.sqlite + def test_no_warning_for_dual_dbapi_clsmethod(self): + """The dbapi() class method is renamed to import_dbapi(), + so that the .dbapi attribute can be exclusively an instance + attribute. + + Dialect classes will likely have both a dbapi() classmethod + as well as an import_dbapi() class method to maintain + cross-compatibility. Make sure these updated classes don't get a + warning and that the new method is used. + + """ + + from sqlalchemy.dialects.sqlite import pysqlite + from sqlalchemy.dialects import registry + + canary = mock.Mock() + + class MyDialect(pysqlite.SQLiteDialect_pysqlite): + @classmethod + def dbapi(cls): + canary.dbapi() + return __import__("sqlite3") + + @classmethod + def import_dbapi(cls): + canary.import_dbapi() + return __import__("sqlite3") + + tokens = __name__.split(".") + + global dialect + dialect = MyDialect + + registry.register( + "mockdialect2.sqlite", ".".join(tokens[0:-1]), tokens[-1] + ) + + # no warning + e = create_engine("mockdialect2+sqlite://") + + eq_(canary.mock_calls, [mock.call.import_dbapi()]) + sqlite3 = __import__("sqlite3") + is_(e.dialect.dbapi, sqlite3) + def test_strategy_keyword_mock(self): def executor(x, y): pass diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 59bc4863fb..dbd957703f 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -2285,6 +2285,40 @@ class EngineEventsTest(fixtures.TestBase): [call(c2, {"c1": "opt_c1"}), call(c4, {"c3": "opt_c3"})], ) + def test_execution_options_modify_inplace(self): + engine = engines.testing_engine() + + @event.listens_for(engine, "set_engine_execution_options") + def engine_tracker(conn, opt): + opt["engine_tracked"] = True + + @event.listens_for(engine, "set_connection_execution_options") + def conn_tracker(conn, opt): + opt["conn_tracked"] = True + + with mock.patch.object( + engine.dialect, "set_connection_execution_options" + ) as conn_opt, mock.patch.object( + engine.dialect, "set_engine_execution_options" + ) as engine_opt: + e2 = engine.execution_options(e1="opt_e1") + c1 = engine.connect() + c2 = c1.execution_options(c1="opt_c1") + + is_not(e2, engine) + is_(c1, c2) + + eq_(e2._execution_options, {"e1": "opt_e1", "engine_tracked": True}) + eq_(c2._execution_options, {"c1": "opt_c1", "conn_tracked": True}) + eq_( + engine_opt.mock_calls, + [mock.call(e2, {"e1": "opt_e1", "engine_tracked": True})], + ) + eq_( + conn_opt.mock_calls, + [mock.call(c1, {"c1": "opt_c1", "conn_tracked": True})], + ) + @testing.requires.sequences @testing.provide_metadata def test_cursor_execute(self): diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 4c378dda18..23be61aaf8 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1111,7 +1111,7 @@ class TestGetDialect(fixtures.TestBase): class MockDialect(DefaultDialect): @classmethod - def dbapi(cls, **kw): + def import_dbapi(cls, **kw): return MockDBAPI() diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index c1f0639bb8..b8d9a3618c 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1023,6 +1023,23 @@ class RealReconnectTest(fixtures.TestBase): eq_(conn.execute(select(1)).scalar(), 1) assert not conn.invalidated + def test_detach_invalidated(self): + with self.engine.connect() as conn: + conn.invalidate() + with expect_raises_message( + exc.InvalidRequestError, + "Can't detach an invalidated Connection", + ): + conn.detach() + + def test_detach_closed(self): + with self.engine.connect() as conn: + pass + with expect_raises_message( + exc.ResourceClosedError, "This Connection is closed" + ): + conn.detach() + @testing.requires.independent_connections def test_multiple_invalidate(self): c1 = self.engine.connect() @@ -1078,8 +1095,23 @@ class RealReconnectTest(fixtures.TestBase): conn.begin() trans2 = conn.begin_nested() conn.invalidate() + + # this passes silently, as it will often be involved + # in error catching schemes trans2.rollback() + # still invalid though + with expect_raises(exc.PendingRollbackError): + conn.begin_nested() + + def test_no_begin_on_invalid(self): + with self.engine.connect() as conn: + conn.begin() + conn.invalidate() + + with expect_raises(exc.PendingRollbackError): + conn.commit() + def test_invalidate_twice(self): with self.engine.connect() as conn: conn.invalidate() diff --git a/test/profiles.txt b/test/profiles.txt index 67f155ebac..750b577800 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -98,48 +98,48 @@ test.aaa_profiling.test_misc.EnumTest.test_create_enum_from_pep_435_w_expensive_ # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 50235 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 62055 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 50735 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_w_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61045 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 49335 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61155 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 49435 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_bundle_wo_annotation x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 59745 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53235 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 63155 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 53335 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 61745 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52335 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 62255 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 52435 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 60845 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 45435 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 49255 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 45035 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 48345 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48235 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 57555 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 48335 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 56145 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 47335 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 56655 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 47435 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_bundle_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 55245 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 34205 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 33805 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_w_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37005 # TEST: test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 33305 -test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 37005 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 32905 +test.aaa_profiling.test_orm.AnnotatedOverheadTest.test_no_entity_wo_annotations x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 36105 # TEST: test.aaa_profiling.test_orm.AttributeOverheadTest.test_attribute_set @@ -163,13 +163,13 @@ test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15227 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 34246 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15313 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26332 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21393 -test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 28412 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 21377 +test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26396 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_aliased @@ -188,23 +188,23 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpy # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98439 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 103939 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 98506 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 104006 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96819 -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102304 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96844 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102344 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520593 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522453 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 520615 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 522475 # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 431505 -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 464305 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 431905 +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 450605 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity @@ -213,18 +213,18 @@ test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_ # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 106782 -test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 115789 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 106870 +test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 115127 # TEST: test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 19931 -test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 21463 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 20030 +test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 21434 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_load -test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1362 -test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1458 +test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1366 +test.aaa_profiling.test_orm.MergeTest.test_merge_load x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1455 # TEST: test.aaa_profiling.test_orm.MergeTest.test_merge_no_load @@ -233,18 +233,18 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3. # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols -test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6109 -test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 7329 +test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6167 +test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 6987 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 258605 -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 281805 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 259205 +test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 278405 # TEST: test.aaa_profiling.test_orm.SessionTest.test_expire_lots -test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1276 -test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1268 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 1252 +test.aaa_profiling.test_orm.SessionTest.test_expire_lots x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 1260 # TEST: test.aaa_profiling.test_pool.QueuePoolTest.test_first_connect diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index e5b1a0a269..ff70fc184a 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1612,15 +1612,15 @@ class CursorResultTest(fixtures.TablesTest): eq_(dict(row._mapping), {"a": "av", "b": "bv", "count": "cv"}) - with assertions.expect_raises_message( + with assertions.expect_raises( TypeError, - "TypeError: tuple indices must be integers or slices, not str", + "tuple indices must be integers or slices, not str", ): eq_(row["a"], "av") with assertions.expect_raises_message( TypeError, - "TypeError: tuple indices must be integers or slices, not str", + "tuple indices must be integers or slices, not str", ): eq_(row["count"], "cv") @@ -3197,8 +3197,7 @@ class GenerativeResultTest(fixtures.TablesTest): all_ = result.columns(*columns).all() eq_(all_, expected) - # ensure Row / LegacyRow comes out with .columns - assert type(all_[0]) is result._process_row + assert type(all_[0]) is Row def test_columns_twice(self, connection): users = self.tables.users @@ -3216,8 +3215,7 @@ class GenerativeResultTest(fixtures.TablesTest): ) eq_(all_, [("jack", 1)]) - # ensure Row / LegacyRow comes out with .columns - assert type(all_[0]) is result._process_row + assert type(all_[0]) is Row def test_columns_plus_getter(self, connection): users = self.tables.users