]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow using AsyncEngine in compile
authorEugene Toder <eltoder@gmail.com>
Mon, 11 Mar 2024 11:42:47 +0000 (07:42 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 11 Mar 2024 20:42:01 +0000 (21:42 +0100)
This works, so only need to update the type annotation.

This pull request is:

- [x] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed

Closes: #11103
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11103
Pull-request-sha: ba9e61a3902d5666a5176aedd50afe8ae7762bff

Change-Id: I3d08b930a8cae0539bf9b436d5e806d8912cdee0

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
test/ext/asyncio/test_engine_py3k.py
test/typing/plain_files/engine/engines.py
test/typing/plain_files/ext/asyncio/engines.py

index 63631bdbd739cfe68918c5ab158432a24a179ea2..a674c5902b60d6e665b38d753471a0d220962529 100644 (file)
@@ -113,6 +113,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
 
     """
 
+    dialect: Dialect
     dispatch: dispatcher[ConnectionEventsTarget]
 
     _sqla_logger_namespace = "sqlalchemy.engine.Connection"
index 2be452747edf97630c4bba93544cc5e5880451a3..16d14ef5dbe6effa06b4ee92f94ce614b6efd456 100644 (file)
@@ -933,7 +933,7 @@ class AsyncConnection(
         return self._proxied.invalidated
 
     @property
-    def dialect(self) -> Any:
+    def dialect(self) -> Dialect:
         r"""Proxy for the :attr:`_engine.Connection.dialect` attribute
         on behalf of the :class:`_asyncio.AsyncConnection` class.
 
@@ -942,7 +942,7 @@ class AsyncConnection(
         return self._proxied.dialect
 
     @dialect.setter
-    def dialect(self, attr: Any) -> None:
+    def dialect(self, attr: Dialect) -> None:
         self._proxied.dialect = attr
 
     @property
index 2b50f2bdabed70bd71e6f9bcdb1a580c235b2e33..570db02aacd7030b654a6a69873d8f027e4dae29 100644 (file)
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
     from .sqltypes import TableValueType
     from .sqltypes import TupleType
     from .type_api import TypeEngine
+    from ..engine import Dialect
     from ..util.typing import TypeGuard
 
 _T = TypeVar("_T", bound=Any)
@@ -93,6 +94,15 @@ class _CoreAdapterProto(Protocol):
     def __call__(self, obj: _CE) -> _CE: ...
 
 
+class _HasDialect(Protocol):
+    """protocol for Engine/Connection-like objects that have dialect
+    attribute.
+    """
+
+    @property
+    def dialect(self) -> Dialect: ...
+
+
 # match column types that are not ORM entities
 _NOT_ENTITY = TypeVar(
     "_NOT_ENTITY",
index 98f45d9dbf76ba745cc8a2820624d588529baa48..8f10dd8d5c143891bf9476f4e456b3f7e7e96ef2 100644 (file)
@@ -85,6 +85,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
+    from ._typing import _HasDialect
     from ._typing import _InfoType
     from ._typing import _PropagateAttrsType
     from ._typing import _TypeEngineArgument
@@ -109,7 +110,6 @@ if typing.TYPE_CHECKING:
     from .visitors import anon_map
     from ..engine import Connection
     from ..engine import Dialect
-    from ..engine import Engine
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import CompiledCacheType
@@ -246,7 +246,7 @@ class CompilerElement(Visitable):
     @util.preload_module("sqlalchemy.engine.url")
     def compile(
         self,
-        bind: Optional[Union[Engine, Connection]] = None,
+        bind: Optional[_HasDialect] = None,
         dialect: Optional[Dialect] = None,
         **kw: Any,
     ) -> Compiled:
@@ -780,7 +780,7 @@ class DQLDMLClauseElement(ClauseElement):
 
         def compile(  # noqa: A001
             self,
-            bind: Optional[Union[Engine, Connection]] = None,
+            bind: Optional[_HasDialect] = None,
             dialect: Optional[Dialect] = None,
             **kw: Any,
         ) -> SQLCompiler: ...
index c3d1e4835a02f3287aa99a4e8bd1a0269ee8cc05..ee5953636d447a56d7107ca6747825612965fe57 100644 (file)
@@ -403,6 +403,13 @@ class AsyncEngineTest(EngineFixture):
 
             eq_(m.mock_calls, [])
 
+    @async_test
+    async def test_statement_compile(self, async_engine):
+        stmt = _select1(async_engine)
+        eq_(str(select(1).compile(async_engine)), stmt)
+        async with async_engine.connect() as conn:
+            eq_(str(select(1).compile(conn)), stmt)
+
     def test_clear_compiled_cache(self, async_engine):
         async_engine.sync_engine._compiled_cache["foo"] = "bar"
         eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
@@ -954,19 +961,13 @@ class AsyncEventTest(EngineFixture):
         ):
             event.listen(async_engine, "checkout", mock.Mock())
 
-    def select1(self, engine):
-        if engine.dialect.name == "oracle":
-            return "select 1 from dual"
-        else:
-            return "select 1"
-
     @async_test
     async def test_sync_before_cursor_execute_engine(self, async_engine):
         canary = mock.Mock()
 
         event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
 
-        s1 = self.select1(async_engine)
+        s1 = _select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
             await conn.execute(text(s1))
@@ -980,7 +981,7 @@ class AsyncEventTest(EngineFixture):
     async def test_sync_before_cursor_execute_connection(self, async_engine):
         canary = mock.Mock()
 
-        s1 = self.select1(async_engine)
+        s1 = _select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
 
@@ -1522,3 +1523,10 @@ class PoolRegenTest(EngineFixture):
 
         tasks = [thing(engine) for _ in range(10)]
         await asyncio.gather(*tasks)
+
+
+def _select1(engine):
+    if engine.dialect.name == "oracle":
+        return "SELECT 1 FROM DUAL"
+    else:
+        return "SELECT 1"
index 7d56c51a5bb7023c9106820f98b2451a45257bd3..15aa774e6aed891b1026e9058321425f96735fb0 100644 (file)
@@ -1,5 +1,6 @@
 from sqlalchemy import create_engine
 from sqlalchemy import Pool
+from sqlalchemy import select
 from sqlalchemy import text
 
 
@@ -30,5 +31,9 @@ def regular() -> None:
     engine = create_engine("postgresql://scott:tiger@localhost/test")
     status: str = engine.pool.status()
     other_pool: Pool = engine.pool.recreate()
+    ce = select(1).compile(e)
+    ce.statement
+    cc = select(1).compile(conn)
+    cc.statement
 
     print(status, other_pool)
index 1b13ff1e9524ad1216b3a8ab37a5d3968bca8855..df4b0a0f645aef26cebf647db94bd06c6ab9bfe8 100644 (file)
@@ -1,6 +1,7 @@
 from typing import Any
 
 from sqlalchemy import Connection
+from sqlalchemy import select
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
@@ -65,3 +66,8 @@ async def asyncio() -> None:
 
         # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection"
         await conn.run_sync(work_sync)
+
+    ce = select(1).compile(e)
+    ce.statement
+    cc = select(1).compile(conn)
+    cc.statement