From ed3fa95e1671f3d920068f6bd36d23b98f9533bd Mon Sep 17 00:00:00 2001 From: Eugene Toder Date: Mon, 11 Mar 2024 07:42:47 -0400 Subject: [PATCH] Allow using AsyncEngine in compile This works, so only need to update the type annotation. This pull request is: - [x] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed Closes: #11103 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11103 Pull-request-sha: ba9e61a3902d5666a5176aedd50afe8ae7762bff Change-Id: I3d08b930a8cae0539bf9b436d5e806d8912cdee0 (cherry picked from commit d2a743d0bcd88129f571f2256cd18f1b02036fd2) --- lib/sqlalchemy/engine/base.py | 1 + lib/sqlalchemy/ext/asyncio/engine.py | 4 ++-- lib/sqlalchemy/sql/_typing.py | 10 ++++++++ lib/sqlalchemy/sql/elements.py | 6 ++--- test/ext/asyncio/test_engine_py3k.py | 24 ++++++++++++------- test/typing/plain_files/engine/engines.py | 5 ++++ .../typing/plain_files/ext/asyncio/engines.py | 6 +++++ 7 files changed, 43 insertions(+), 13 deletions(-) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 3c11d14d5b..403ec452b9 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -109,6 +109,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ + dialect: Dialect dispatch: dispatcher[ConnectionEventsTarget] _sqla_logger_namespace = "sqlalchemy.engine.Connection" diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 5d7d7e6b42..8fc8e96db0 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -930,7 +930,7 @@ class AsyncConnection( return self._proxied.invalidated @property - def dialect(self) -> Any: + def dialect(self) -> Dialect: r"""Proxy for the :attr:`_engine.Connection.dialect` attribute on behalf of the :class:`_asyncio.AsyncConnection` class. @@ -939,7 +939,7 @@ class AsyncConnection( return self._proxied.dialect @dialect.setter - def dialect(self, attr: Any) -> None: + def dialect(self, attr: Dialect) -> None: self._proxied.dialect = attr @property diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index ea9cbe1f48..ba5faffd4d 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -69,6 +69,7 @@ if TYPE_CHECKING: from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine + from ..engine import Dialect from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) @@ -92,6 +93,15 @@ class _CoreAdapterProto(Protocol): def __call__(self, obj: _CE) -> _CE: ... +class _HasDialect(Protocol): + """protocol for Engine/Connection-like objects that have dialect + attribute. + """ + + @property + def dialect(self) -> Dialect: ... + + # match column types that are not ORM entities _NOT_ENTITY = TypeVar( "_NOT_ENTITY", diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e5a9fb0624..9f0ed10a4c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -83,6 +83,7 @@ if typing.TYPE_CHECKING: from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _HasDialect from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument @@ -107,7 +108,6 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Dialect - from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType @@ -244,7 +244,7 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.url") def compile( self, - bind: Optional[Union[Engine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> Compiled: @@ -776,7 +776,7 @@ class DQLDMLClauseElement(ClauseElement): def compile( # noqa: A001 self, - bind: Optional[Union[Engine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> SQLCompiler: ... diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index c12363f4d0..9fb12e6936 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -403,6 +403,13 @@ class AsyncEngineTest(EngineFixture): eq_(m.mock_calls, []) + @async_test + async def test_statement_compile(self, async_engine): + stmt = _select1(async_engine) + eq_(str(select(1).compile(async_engine)), stmt) + async with async_engine.connect() as conn: + eq_(str(select(1).compile(conn)), stmt) + def test_clear_compiled_cache(self, async_engine): async_engine.sync_engine._compiled_cache["foo"] = "bar" eq_(async_engine.sync_engine._compiled_cache["foo"], "bar") @@ -954,19 +961,13 @@ class AsyncEventTest(EngineFixture): ): event.listen(async_engine, "checkout", mock.Mock()) - def select1(self, engine): - if engine.dialect.name == "oracle": - return "select 1 from dual" - else: - return "select 1" - @async_test async def test_sync_before_cursor_execute_engine(self, async_engine): canary = mock.Mock() event.listen(async_engine.sync_engine, "before_cursor_execute", canary) - s1 = self.select1(async_engine) + s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection await conn.execute(text(s1)) @@ -980,7 +981,7 @@ class AsyncEventTest(EngineFixture): async def test_sync_before_cursor_execute_connection(self, async_engine): canary = mock.Mock() - s1 = self.select1(async_engine) + s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection @@ -1522,3 +1523,10 @@ class PoolRegenTest(EngineFixture): tasks = [thing(engine) for _ in range(10)] await asyncio.gather(*tasks) + + +def _select1(engine): + if engine.dialect.name == "oracle": + return "SELECT 1 FROM DUAL" + else: + return "SELECT 1" diff --git a/test/typing/plain_files/engine/engines.py b/test/typing/plain_files/engine/engines.py index 5777b91484..a204fb9182 100644 --- a/test/typing/plain_files/engine/engines.py +++ b/test/typing/plain_files/engine/engines.py @@ -1,5 +1,6 @@ from sqlalchemy import create_engine from sqlalchemy import Pool +from sqlalchemy import select from sqlalchemy import text @@ -30,5 +31,9 @@ def regular() -> None: engine = create_engine("postgresql://scott:tiger@localhost/test") status: str = engine.pool.status() other_pool: Pool = engine.pool.recreate() + ce = select(1).compile(e) + ce.statement + cc = select(1).compile(conn) + cc.statement print(status, other_pool) diff --git a/test/typing/plain_files/ext/asyncio/engines.py b/test/typing/plain_files/ext/asyncio/engines.py index 01475dc71e..1f7843082a 100644 --- a/test/typing/plain_files/ext/asyncio/engines.py +++ b/test/typing/plain_files/ext/asyncio/engines.py @@ -1,6 +1,7 @@ from typing import Any from sqlalchemy import Connection +from sqlalchemy import select from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine @@ -65,3 +66,8 @@ async def asyncio() -> None: # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection" await conn.run_sync(work_sync) + + ce = select(1).compile(e) + ce.statement + cc = select(1).compile(conn) + cc.statement -- 2.47.2