--- /dev/null
+.. change::
+ :tags: oracle, usecase
+ :tickets: 10820
+
+ Added API support for server-side cursors for the oracledb async dialect,
+ allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar
+ stream methods.
import collections
import sys
from typing import Any
+from typing import AsyncIterator
from typing import Deque
from typing import Iterator
from typing import NoReturn
async def nextset(self) -> Optional[bool]: ...
+ def __aiter__(self) -> AsyncIterator[Any]: ...
+
class AsyncAdapt_dbapi_cursor:
server_side = False
cursor = self._make_new_cursor(self._connection)
self._cursor = self._aenter_cursor(cursor)
- self._rows = collections.deque()
+ if not self.server_side:
+ self._rows = collections.deque()
def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
try:
def fetchall(self) -> Sequence[Any]:
return await_(self._cursor.fetchall())
+ def __iter__(self) -> Iterator[Any]:
+ iterator = self._cursor.__aiter__()
+ while True:
+ try:
+ yield await_(iterator.__anext__())
+ except StopAsyncIteration:
+ break
+
class AsyncAdapt_dbapi_connection(AdaptedConnection):
_cursor_cls = AsyncAdapt_dbapi_cursor
from typing import Any
from typing import TYPE_CHECKING
-from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
+from . import cx_oracle as _cx_oracle
from ... import exc
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...engine import default
from ...util import await_
if TYPE_CHECKING:
from oracledb import AsyncCursor
-class OracleDialect_oracledb(_OracleDialect_cx_oracle):
+class OracleExecutionContext_oracledb(
+ _cx_oracle.OracleExecutionContext_cx_oracle
+):
+ pass
+
+
+class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
supports_statement_cache = True
+ execution_ctx_cls = OracleExecutionContext_oracledb
+
driver = "oracledb"
_min_version = (1,)
return await self._cursor.executemany(operation, seq_of_parameters)
+class AsyncAdapt_oracledb_ss_cursor(
+ AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor
+):
+ __slots__ = ()
+
+ def close(self) -> None:
+ if self._cursor is not None:
+ self._cursor.close()
+ self._cursor = None # type: ignore
+
+
class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
_connection: AsyncConnection
__slots__ = ()
def cursor(self):
return AsyncAdapt_oracledb_cursor(self)
+ def ss_cursor(self):
+ return AsyncAdapt_oracledb_ss_cursor(self)
+
def xid(self, *args: Any, **kwargs: Any) -> Any:
return self._connection.xid(*args, **kwargs)
)
+class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb):
+ # restore default create cursor
+ create_cursor = default.DefaultExecutionContext.create_cursor
+
+ def create_default_cursor(self):
+ # copy of OracleExecutionContext_cx_oracle.create_cursor
+ c = self._dbapi_connection.cursor()
+ if self.dialect.arraysize:
+ c.arraysize = self.dialect.arraysize
+
+ return c
+
+ def create_server_side_cursor(self):
+ c = self._dbapi_connection.ss_cursor()
+ if self.dialect.arraysize:
+ c.arraysize = self.dialect.arraysize
+
+ return c
+
+
class OracleDialectAsync_oracledb(OracleDialect_oracledb):
is_async = True
+ supports_server_side_cursors = True
supports_statement_cache = True
+ execution_ctx_cls = OracleExecutionContextAsync_oracledb
_min_version = (2,)
"_invalidate_schema_cache_asof",
)
- server_side = False
-
_adapt_connection: AsyncAdapt_asyncpg_connection
_connection: _AsyncpgConnection
_cursor: Optional[_AsyncpgCursor]
class AsyncAdapt_asyncpg_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncpg_cursor
):
- server_side = True
__slots__ = ("_rowbuffer",)
def __init__(self, adapt_connection):
def _make_new_cursor(self, connection):
return connection.cursor(self.name)
- # TODO: should this be on the base asyncio adapter?
- def __iter__(self):
- iterator = self._cursor.__aiter__()
- while True:
- try:
- yield await_(iterator.__anext__())
- except StopAsyncIteration:
- break
-
class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
_connection: AsyncConnection
# mypy: ignore-errors
import datetime
+import re
from .. import engines
from .. import fixtures
return getattr(cursor, "server_side", False)
elif self.engine.dialect.driver == "psycopg":
return bool(getattr(cursor, "name", False))
+ elif self.engine.dialect.driver == "oracledb":
+ return getattr(cursor, "server_side", False)
else:
return False
)
return self.engine
+ def stringify(self, str_):
+ return re.compile(r"SELECT (\d+)", re.I).sub(
+ lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
+ )
+
@testing.combinations(
- ("global_string", True, "select 1", True),
- ("global_text", True, text("select 1"), True),
+ ("global_string", True, lambda stringify: stringify("select 1"), True),
+ (
+ "global_text",
+ True,
+ lambda stringify: text(stringify("select 1")),
+ True,
+ ),
("global_expr", True, select(1), True),
- ("global_off_explicit", False, text("select 1"), False),
+ (
+ "global_off_explicit",
+ False,
+ lambda stringify: text(stringify("select 1")),
+ False,
+ ),
(
"stmt_option",
False,
(
"for_update_string",
True,
- "SELECT 1 FOR UPDATE",
+ lambda stringify: stringify("SELECT 1 FOR UPDATE"),
True,
testing.skip_if(["sqlite", "mssql"]),
),
- ("text_no_ss", False, text("select 42"), False),
+ (
+ "text_no_ss",
+ False,
+ lambda stringify: text(stringify("select 42")),
+ False,
+ ),
(
"text_ss_option",
False,
- text("select 42").execution_options(stream_results=True),
+ lambda stringify: text(stringify("select 42")).execution_options(
+ stream_results=True
+ ),
True,
),
id_="iaaa",
):
engine = self._fixture(engine_ss_arg)
with engine.begin() as conn:
+ if callable(statement):
+ statement = testing.resolve_lambda(
+ statement, stringify=self.stringify
+ )
+
if isinstance(statement, str):
result = conn.exec_driver_sql(statement)
else:
# should be enabled for this one
result = conn.execution_options(
stream_results=True
- ).exec_driver_sql("select 1")
+ ).exec_driver_sql(self.stringify("select 1"))
assert self._is_server_side(result.cursor)
# the connection has autobegun, which means at the end of the
test_table = Table(
"test_table",
md,
- Column("id", Integer, primary_key=True),
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
Column("data", String(50)),
)
test_table = Table(
"test_table",
md,
- Column("id", Integer, primary_key=True),
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
Column("data", String(50)),
)
is_(fairy.connection, fairy.dbapi_connection)
-def select1(db):
- return str(select(1).compile(dialect=db.dialect))
-
-
class ResetEventTest(fixtures.TestBase):
def _fixture(self, **kw):
dbapi = Mock()
def test_new_exec_driver_sql_no_events(self):
m1 = Mock()
- def select1(db):
- return str(select(1).compile(dialect=db.dialect))
-
with testing.db.connect() as conn:
event.listen(conn, "before_execute", m1.before_execute)
event.listen(conn, "after_execute", m1.after_execute)
- conn.exec_driver_sql(select1(testing.db))
+ conn.exec_driver_sql(str(select(1).compile(testing.db)))
eq_(m1.mock_calls, [])
def test_add_event_after_connect(self, testing_engine):
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
+from sqlalchemy import true
from sqlalchemy import union_all
from sqlalchemy.engine import cursor as _cursor
from sqlalchemy.ext.asyncio import async_engine_from_config
@async_test
async def test_statement_compile(self, async_engine):
- stmt = _select1(async_engine)
- eq_(str(select(1).compile(async_engine)), stmt)
+ stmt = str(select(1).compile(async_engine))
async with async_engine.connect() as conn:
eq_(str(select(1).compile(conn)), stmt)
event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
- s1 = _select1(async_engine)
async with async_engine.connect() as conn:
sync_conn = conn.sync_connection
- await conn.execute(text(s1))
+ await conn.execute(select(1))
+ s1 = str(select(1).compile(async_engine))
eq_(
canary.mock_calls,
[mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
async def test_sync_before_cursor_execute_connection(self, async_engine):
canary = mock.Mock()
- s1 = _select1(async_engine)
async with async_engine.connect() as conn:
sync_conn = conn.sync_connection
event.listen(
async_engine.sync_engine, "before_cursor_execute", canary
)
- await conn.execute(text(s1))
+ await conn.execute(select(1))
+ s1 = str(select(1).compile(async_engine))
eq_(
canary.mock_calls,
[mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
):
await result.one()
- @testing.combinations(
- ("scalars",), ("stream_scalars",), argnames="filter_"
- )
+ @testing.combinations(("scalars",), ("stream_scalars",), argnames="case")
@async_test
- async def test_scalars(self, async_engine, filter_):
+ async def test_scalars(self, async_engine, case):
users = self.tables.users
async with async_engine.connect() as conn:
- if filter_ == "scalars":
+ if case == "scalars":
result = (await conn.scalars(select(users))).all()
- elif filter_ == "stream_scalars":
+ elif case == "stream_scalars":
result = await (await conn.stream_scalars(select(users))).all()
eq_(result, list(range(1, 20)))
+ @async_test
+ @testing.combinations(("stream",), ("stream_scalars",), argnames="case")
+ async def test_stream_fetch_many_not_complete(self, async_engine, case):
+ users = self.tables.users
+ big_query = select(users).join(users.alias("other"), true())
+ async with async_engine.connect() as conn:
+ if case == "stream":
+ result = await conn.stream(big_query)
+ elif case == "stream_scalars":
+ result = await conn.stream_scalars(big_query)
+
+ f1 = await result.fetchmany(5)
+ f2 = await result.fetchmany(10)
+ f3 = await result.fetchmany(7)
+ eq_(len(f1) + len(f2) + len(f3), 22)
+
+ res = await result.fetchall()
+ eq_(len(res), 19 * 19 - 22)
+
+ @async_test
+ @testing.combinations(("stream",), ("execute",), argnames="case")
+ async def test_cursor_close(self, async_engine, case):
+ users = self.tables.users
+ async with async_engine.connect() as conn:
+ if case == "stream":
+ result = await conn.stream(select(users))
+ cursor = result._real_result.cursor
+ elif case == "execute":
+ result = await conn.execute(select(users))
+ cursor = result.cursor
+
+ await conn.run_sync(lambda _: cursor.close())
+
class TextSyncDBAPI(fixtures.TestBase):
__requires__ = ("asyncio",)
async def thing(engine):
async with engine.connect() as conn:
- await conn.exec_driver_sql("select 1")
+ await conn.exec_driver_sql(str(select(1).compile(engine)))
if do_dispose:
await engine.dispose()
tasks = [thing(engine) for _ in range(10)]
await asyncio.gather(*tasks)
-
-
-def _select1(engine):
- if engine.dialect.name == "oracle":
- return "SELECT 1 FROM DUAL"
- else:
- return "SELECT 1"