]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow using AsyncEngine in compile
authorEugene Toder <eltoder@gmail.com>
Mon, 11 Mar 2024 11:42:47 +0000 (07:42 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 11 Mar 2024 20:45:16 +0000 (21:45 +0100)
This works, so only need to update the type annotation.

This pull request is:

- [x] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed

Closes: #11103
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11103
Pull-request-sha: ba9e61a3902d5666a5176aedd50afe8ae7762bff

Change-Id: I3d08b930a8cae0539bf9b436d5e806d8912cdee0
(cherry picked from commit d2a743d0bcd88129f571f2256cd18f1b02036fd2)

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
test/ext/asyncio/test_engine_py3k.py
test/typing/plain_files/engine/engines.py
test/typing/plain_files/ext/asyncio/engines.py

index 3c11d14d5b72c06dfd9634ccd289494dded0d61a..403ec452b9a2f2c947636c18c2529510722f3f83 100644 (file)
@@ -109,6 +109,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
 
     """
 
+    dialect: Dialect
     dispatch: dispatcher[ConnectionEventsTarget]
 
     _sqla_logger_namespace = "sqlalchemy.engine.Connection"
index 5d7d7e6b4253fb3d4b7c8fca3a3fea4be1422ba4..8fc8e96db063dc9ab15c52536d34572e79f0251b 100644 (file)
@@ -930,7 +930,7 @@ class AsyncConnection(
         return self._proxied.invalidated
 
     @property
-    def dialect(self) -> Any:
+    def dialect(self) -> Dialect:
         r"""Proxy for the :attr:`_engine.Connection.dialect` attribute
         on behalf of the :class:`_asyncio.AsyncConnection` class.
 
@@ -939,7 +939,7 @@ class AsyncConnection(
         return self._proxied.dialect
 
     @dialect.setter
-    def dialect(self, attr: Any) -> None:
+    def dialect(self, attr: Dialect) -> None:
         self._proxied.dialect = attr
 
     @property
index ea9cbe1f482fef453663269f514d8ae9a3104531..ba5faffd4d6bbba1f093f0ae90ac303e529ab1be 100644 (file)
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
     from .sqltypes import TableValueType
     from .sqltypes import TupleType
     from .type_api import TypeEngine
+    from ..engine import Dialect
     from ..util.typing import TypeGuard
 
 _T = TypeVar("_T", bound=Any)
@@ -92,6 +93,15 @@ class _CoreAdapterProto(Protocol):
     def __call__(self, obj: _CE) -> _CE: ...
 
 
+class _HasDialect(Protocol):
+    """protocol for Engine/Connection-like objects that have dialect
+    attribute.
+    """
+
+    @property
+    def dialect(self) -> Dialect: ...
+
+
 # match column types that are not ORM entities
 _NOT_ENTITY = TypeVar(
     "_NOT_ENTITY",
index e5a9fb0624cce647c7dc15d5df00f26a8f0e6001..9f0ed10a4c9e686e7b36e36e8bdbf0a6735b2228 100644 (file)
@@ -83,6 +83,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
+    from ._typing import _HasDialect
     from ._typing import _InfoType
     from ._typing import _PropagateAttrsType
     from ._typing import _TypeEngineArgument
@@ -107,7 +108,6 @@ if typing.TYPE_CHECKING:
     from .visitors import anon_map
     from ..engine import Connection
     from ..engine import Dialect
-    from ..engine import Engine
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import CompiledCacheType
@@ -244,7 +244,7 @@ class CompilerElement(Visitable):
     @util.preload_module("sqlalchemy.engine.url")
     def compile(
         self,
-        bind: Optional[Union[Engine, Connection]] = None,
+        bind: Optional[_HasDialect] = None,
         dialect: Optional[Dialect] = None,
         **kw: Any,
     ) -> Compiled:
@@ -776,7 +776,7 @@ class DQLDMLClauseElement(ClauseElement):
 
         def compile(  # noqa: A001
             self,
-            bind: Optional[Union[Engine, Connection]] = None,
+            bind: Optional[_HasDialect] = None,
             dialect: Optional[Dialect] = None,
             **kw: Any,
         ) -> SQLCompiler: ...
index c12363f4d0b8c1850b0dc9ab2a402963a04f8a66..9fb12e6936f7e5372ea4c471aa5ba4658ed68452 100644 (file)
@@ -403,6 +403,13 @@ class AsyncEngineTest(EngineFixture):
 
             eq_(m.mock_calls, [])
 
+    @async_test
+    async def test_statement_compile(self, async_engine):
+        stmt = _select1(async_engine)
+        eq_(str(select(1).compile(async_engine)), stmt)
+        async with async_engine.connect() as conn:
+            eq_(str(select(1).compile(conn)), stmt)
+
     def test_clear_compiled_cache(self, async_engine):
         async_engine.sync_engine._compiled_cache["foo"] = "bar"
         eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
@@ -954,19 +961,13 @@ class AsyncEventTest(EngineFixture):
         ):
             event.listen(async_engine, "checkout", mock.Mock())
 
-    def select1(self, engine):
-        if engine.dialect.name == "oracle":
-            return "select 1 from dual"
-        else:
-            return "select 1"
-
     @async_test
     async def test_sync_before_cursor_execute_engine(self, async_engine):
         canary = mock.Mock()
 
         event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
 
-        s1 = self.select1(async_engine)
+        s1 = _select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
             await conn.execute(text(s1))
@@ -980,7 +981,7 @@ class AsyncEventTest(EngineFixture):
     async def test_sync_before_cursor_execute_connection(self, async_engine):
         canary = mock.Mock()
 
-        s1 = self.select1(async_engine)
+        s1 = _select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
 
@@ -1522,3 +1523,10 @@ class PoolRegenTest(EngineFixture):
 
         tasks = [thing(engine) for _ in range(10)]
         await asyncio.gather(*tasks)
+
+
+def _select1(engine):
+    if engine.dialect.name == "oracle":
+        return "SELECT 1 FROM DUAL"
+    else:
+        return "SELECT 1"
index 5777b91484180aaa6784f53485c9b2729fb0f7e0..a204fb9182fd96903608e96cebeea252e67eac0f 100644 (file)
@@ -1,5 +1,6 @@
 from sqlalchemy import create_engine
 from sqlalchemy import Pool
+from sqlalchemy import select
 from sqlalchemy import text
 
 
@@ -30,5 +31,9 @@ def regular() -> None:
     engine = create_engine("postgresql://scott:tiger@localhost/test")
     status: str = engine.pool.status()
     other_pool: Pool = engine.pool.recreate()
+    ce = select(1).compile(e)
+    ce.statement
+    cc = select(1).compile(conn)
+    cc.statement
 
     print(status, other_pool)
index 01475dc71e594fcfc486ce518d6c8d3a1ecdfddd..1f7843082a915ba355b487f792d5ff0b12602f28 100644 (file)
@@ -1,6 +1,7 @@
 from typing import Any
 
 from sqlalchemy import Connection
+from sqlalchemy import select
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
@@ -65,3 +66,8 @@ async def asyncio() -> None:
 
         # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection"
         await conn.run_sync(work_sync)
+
+    ce = select(1).compile(e)
+    ce.statement
+    cc = select(1).compile(conn)
+    cc.statement