From: Eugene Toder Date: Mon, 11 Mar 2024 11:42:47 +0000 (-0400) Subject: Allow using AsyncEngine in compile X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d2a743d0bcd88129f571f2256cd18f1b02036fd2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Allow using AsyncEngine in compile This works, so only need to update the type annotation. This pull request is: - [x] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed Closes: #11103 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11103 Pull-request-sha: ba9e61a3902d5666a5176aedd50afe8ae7762bff Change-Id: I3d08b930a8cae0539bf9b436d5e806d8912cdee0 --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 63631bdbd7..a674c5902b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -113,6 +113,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ + dialect: Dialect dispatch: dispatcher[ConnectionEventsTarget] _sqla_logger_namespace = "sqlalchemy.engine.Connection" diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 2be452747e..16d14ef5db 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -933,7 +933,7 @@ class AsyncConnection( return self._proxied.invalidated @property - def dialect(self) -> Any: + def dialect(self) -> Dialect: r"""Proxy for the :attr:`_engine.Connection.dialect` attribute on behalf of the :class:`_asyncio.AsyncConnection` class. @@ -942,7 +942,7 @@ class AsyncConnection( return self._proxied.dialect @dialect.setter - def dialect(self, attr: Any) -> None: + def dialect(self, attr: Dialect) -> None: self._proxied.dialect = attr @property diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 2b50f2bdab..570db02aac 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -70,6 +70,7 @@ if TYPE_CHECKING: from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine + from ..engine import Dialect from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) @@ -93,6 +94,15 @@ class _CoreAdapterProto(Protocol): def __call__(self, obj: _CE) -> _CE: ... +class _HasDialect(Protocol): + """protocol for Engine/Connection-like objects that have dialect + attribute. + """ + + @property + def dialect(self) -> Dialect: ... + + # match column types that are not ORM entities _NOT_ENTITY = TypeVar( "_NOT_ENTITY", diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 98f45d9dbf..8f10dd8d5c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -85,6 +85,7 @@ if typing.TYPE_CHECKING: from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _HasDialect from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument @@ -109,7 +110,6 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Dialect - from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType @@ -246,7 +246,7 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.url") def compile( self, - bind: Optional[Union[Engine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> Compiled: @@ -780,7 +780,7 @@ class DQLDMLClauseElement(ClauseElement): def compile( # noqa: A001 self, - bind: Optional[Union[Engine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> SQLCompiler: ... diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index c3d1e4835a..ee5953636d 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -403,6 +403,13 @@ class AsyncEngineTest(EngineFixture): eq_(m.mock_calls, []) + @async_test + async def test_statement_compile(self, async_engine): + stmt = _select1(async_engine) + eq_(str(select(1).compile(async_engine)), stmt) + async with async_engine.connect() as conn: + eq_(str(select(1).compile(conn)), stmt) + def test_clear_compiled_cache(self, async_engine): async_engine.sync_engine._compiled_cache["foo"] = "bar" eq_(async_engine.sync_engine._compiled_cache["foo"], "bar") @@ -954,19 +961,13 @@ class AsyncEventTest(EngineFixture): ): event.listen(async_engine, "checkout", mock.Mock()) - def select1(self, engine): - if engine.dialect.name == "oracle": - return "select 1 from dual" - else: - return "select 1" - @async_test async def test_sync_before_cursor_execute_engine(self, async_engine): canary = mock.Mock() event.listen(async_engine.sync_engine, "before_cursor_execute", canary) - s1 = self.select1(async_engine) + s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection await conn.execute(text(s1)) @@ -980,7 +981,7 @@ class AsyncEventTest(EngineFixture): async def test_sync_before_cursor_execute_connection(self, async_engine): canary = mock.Mock() - s1 = self.select1(async_engine) + s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection @@ -1522,3 +1523,10 @@ class PoolRegenTest(EngineFixture): tasks = [thing(engine) for _ in range(10)] await asyncio.gather(*tasks) + + +def _select1(engine): + if engine.dialect.name == "oracle": + return "SELECT 1 FROM DUAL" + else: + return "SELECT 1" diff --git a/test/typing/plain_files/engine/engines.py b/test/typing/plain_files/engine/engines.py index 7d56c51a5b..15aa774e6a 100644 --- a/test/typing/plain_files/engine/engines.py +++ b/test/typing/plain_files/engine/engines.py @@ -1,5 +1,6 @@ from sqlalchemy import create_engine from sqlalchemy import Pool +from sqlalchemy import select from sqlalchemy import text @@ -30,5 +31,9 @@ def regular() -> None: engine = create_engine("postgresql://scott:tiger@localhost/test") status: str = engine.pool.status() other_pool: Pool = engine.pool.recreate() + ce = select(1).compile(e) + ce.statement + cc = select(1).compile(conn) + cc.statement print(status, other_pool) diff --git a/test/typing/plain_files/ext/asyncio/engines.py b/test/typing/plain_files/ext/asyncio/engines.py index 1b13ff1e95..df4b0a0f64 100644 --- a/test/typing/plain_files/ext/asyncio/engines.py +++ b/test/typing/plain_files/ext/asyncio/engines.py @@ -1,6 +1,7 @@ from typing import Any from sqlalchemy import Connection +from sqlalchemy import select from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine @@ -65,3 +66,8 @@ async def asyncio() -> None: # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection" await conn.run_sync(work_sync) + + ce = select(1).compile(e) + ce.statement + cc = select(1).compile(conn) + cc.statement