From ba9e61a3902d5666a5176aedd50afe8ae7762bff Mon Sep 17 00:00:00 2001 From: Eugene Toder Date: Thu, 7 Mar 2024 13:50:27 -0500 Subject: [PATCH] Use a protocol for dialect and add tests --- lib/sqlalchemy/sql/_typing.py | 7 +++++++ lib/sqlalchemy/sql/elements.py | 7 +++---- test/ext/asyncio/test_engine_py3k.py | 6 ++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 2b50f2bdab..23b0f7bc85 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -70,6 +70,7 @@ if TYPE_CHECKING: from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine + from ..engine import Dialect from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) @@ -93,6 +94,12 @@ class _CoreAdapterProto(Protocol): def __call__(self, obj: _CE) -> _CE: ... +class _HasDialect(Protocol): + """protocol for Engine/Connection-like things that have dialect""" + + dialect: Dialect + + # match column types that are not ORM entities _NOT_ENTITY = TypeVar( "_NOT_ENTITY", diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 1afc211b75..8f10dd8d5c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -85,6 +85,7 @@ if typing.TYPE_CHECKING: from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _HasDialect from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument @@ -109,14 +110,12 @@ if typing.TYPE_CHECKING: from .visitors import anon_map from ..engine import Connection from ..engine import Dialect - from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result - from ..ext.asyncio import AsyncEngine _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] @@ -247,7 +246,7 @@ class CompilerElement(Visitable): @util.preload_module("sqlalchemy.engine.url") def compile( self, - bind: Optional[Union[Engine, AsyncEngine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> Compiled: @@ -781,7 +780,7 @@ class DQLDMLClauseElement(ClauseElement): def compile( # noqa: A001 self, - bind: Optional[Union[Engine, AsyncEngine, Connection]] = None, + bind: Optional[_HasDialect] = None, dialect: Optional[Dialect] = None, **kw: Any, ) -> SQLCompiler: ... diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index c3d1e4835a..d1cfcdd48f 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -403,6 +403,12 @@ class AsyncEngineTest(EngineFixture): eq_(m.mock_calls, []) + @async_test + async def test_statement_compile(self, async_engine): + eq_(str(select(1).compile(async_engine)), "SELECT 1") + async with async_engine.connect() as conn: + eq_(str(select(1).compile(conn)), "SELECT 1") + def test_clear_compiled_cache(self, async_engine): async_engine.sync_engine._compiled_cache["foo"] = "bar" eq_(async_engine.sync_engine._compiled_cache["foo"], "bar") -- 2.47.2