from .sqltypes import TableValueType
from .sqltypes import TupleType
from .type_api import TypeEngine
+ from ..engine import Dialect
from ..util.typing import TypeGuard
_T = TypeVar("_T", bound=Any)
def __call__(self, obj: _CE) -> _CE: ...
+class _HasDialect(Protocol):
+ """protocol for Engine/Connection-like things that have dialect"""
+
+ dialect: Dialect
+
+
# match column types that are not ORM entities
_NOT_ENTITY = TypeVar(
"_NOT_ENTITY",
from ._typing import _ByArgument
from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnExpressionOrStrLabelArgument
+ from ._typing import _HasDialect
from ._typing import _InfoType
from ._typing import _PropagateAttrsType
from ._typing import _TypeEngineArgument
from .visitors import anon_map
from ..engine import Connection
from ..engine import Dialect
- from ..engine import Engine
from ..engine.interfaces import _CoreMultiExecuteParams
from ..engine.interfaces import CacheStats
from ..engine.interfaces import CompiledCacheType
from ..engine.interfaces import CoreExecuteOptionsParameter
from ..engine.interfaces import SchemaTranslateMapType
from ..engine.result import Result
- from ..ext.asyncio import AsyncEngine
_NUMERIC = Union[float, Decimal]
_NUMBER = Union[float, int, Decimal]
@util.preload_module("sqlalchemy.engine.url")
def compile(
self,
- bind: Optional[Union[Engine, AsyncEngine, Connection]] = None,
+ bind: Optional[_HasDialect] = None,
dialect: Optional[Dialect] = None,
**kw: Any,
) -> Compiled:
def compile( # noqa: A001
self,
- bind: Optional[Union[Engine, AsyncEngine, Connection]] = None,
+ bind: Optional[_HasDialect] = None,
dialect: Optional[Dialect] = None,
**kw: Any,
) -> SQLCompiler: ...
eq_(m.mock_calls, [])
+ @async_test
+ async def test_statement_compile(self, async_engine):
+ eq_(str(select(1).compile(async_engine)), "SELECT 1")
+ async with async_engine.connect() as conn:
+ eq_(str(select(1).compile(conn)), "SELECT 1")
+
def test_clear_compiled_cache(self, async_engine):
async_engine.sync_engine._compiled_cache["foo"] = "bar"
eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")