]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use a protocol for dialect and add tests 11103/head
authorEugene Toder <eltoder@gmail.com>
Thu, 7 Mar 2024 18:50:27 +0000 (13:50 -0500)
committerEugene Toder <eltoder@gmail.com>
Thu, 7 Mar 2024 18:57:16 +0000 (13:57 -0500)
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
test/ext/asyncio/test_engine_py3k.py

index 2b50f2bdabed70bd71e6f9bcdb1a580c235b2e33..23b0f7bc85e36dc38584c3b61eb50f5749e76539 100644 (file)
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
     from .sqltypes import TableValueType
     from .sqltypes import TupleType
     from .type_api import TypeEngine
+    from ..engine import Dialect
     from ..util.typing import TypeGuard
 
 _T = TypeVar("_T", bound=Any)
@@ -93,6 +94,12 @@ class _CoreAdapterProto(Protocol):
     def __call__(self, obj: _CE) -> _CE: ...
 
 
+class _HasDialect(Protocol):
+    """protocol for Engine/Connection-like things that have dialect"""
+
+    dialect: Dialect
+
+
 # match column types that are not ORM entities
 _NOT_ENTITY = TypeVar(
     "_NOT_ENTITY",
index 1afc211b750585690fd23f588e764fe3f13471ea..8f10dd8d5c143891bf9476f4e456b3f7e7e96ef2 100644 (file)
@@ -85,6 +85,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
+    from ._typing import _HasDialect
     from ._typing import _InfoType
     from ._typing import _PropagateAttrsType
     from ._typing import _TypeEngineArgument
@@ -109,14 +110,12 @@ if typing.TYPE_CHECKING:
     from .visitors import anon_map
     from ..engine import Connection
     from ..engine import Dialect
-    from ..engine import Engine
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import CompiledCacheType
     from ..engine.interfaces import CoreExecuteOptionsParameter
     from ..engine.interfaces import SchemaTranslateMapType
     from ..engine.result import Result
-    from ..ext.asyncio import AsyncEngine
 
 _NUMERIC = Union[float, Decimal]
 _NUMBER = Union[float, int, Decimal]
@@ -247,7 +246,7 @@ class CompilerElement(Visitable):
     @util.preload_module("sqlalchemy.engine.url")
     def compile(
         self,
-        bind: Optional[Union[Engine, AsyncEngine, Connection]] = None,
+        bind: Optional[_HasDialect] = None,
         dialect: Optional[Dialect] = None,
         **kw: Any,
     ) -> Compiled:
@@ -781,7 +780,7 @@ class DQLDMLClauseElement(ClauseElement):
 
         def compile(  # noqa: A001
             self,
-            bind: Optional[Union[Engine, AsyncEngine, Connection]] = None,
+            bind: Optional[_HasDialect] = None,
             dialect: Optional[Dialect] = None,
             **kw: Any,
         ) -> SQLCompiler: ...
index c3d1e4835a02f3287aa99a4e8bd1a0269ee8cc05..d1cfcdd48f24edbb5334929e861550222e5a2d89 100644 (file)
@@ -403,6 +403,12 @@ class AsyncEngineTest(EngineFixture):
 
             eq_(m.mock_calls, [])
 
+    @async_test
+    async def test_statement_compile(self, async_engine):
+        eq_(str(select(1).compile(async_engine)), "SELECT 1")
+        async with async_engine.connect() as conn:
+            eq_(str(select(1).compile(conn)), "SELECT 1")
+
     def test_clear_compiled_cache(self, async_engine):
         async_engine.sync_engine._compiled_cache["foo"] = "bar"
         eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")