From: Federico Caselli Date: Thu, 1 Aug 2024 19:16:20 +0000 (+0200) Subject: Added support for server-side cursor in oracledb async dialect. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ffb2e2d033f8e227b80ba3c5d06c67a96310e1ec;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added support for server-side cursor in oracledb async dialect. Added API support for server-side cursors for the oracledb async dialect, allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar stream methods. Fixes: #10820 Change-Id: I861670ccc20a81ec5ee45132b8059fc2a0359087 --- diff --git a/doc/build/changelog/unreleased_20/10820.rst b/doc/build/changelog/unreleased_20/10820.rst new file mode 100644 index 0000000000..e2cc717e2e --- /dev/null +++ b/doc/build/changelog/unreleased_20/10820.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: oracle, usecase + :tickets: 10820 + + Added API support for server-side cursors for the oracledb async dialect, + allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar + stream methods. diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 34820facb6..27d438cda2 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -13,6 +13,7 @@ import asyncio import collections import sys from typing import Any +from typing import AsyncIterator from typing import Deque from typing import Iterator from typing import NoReturn @@ -97,6 +98,8 @@ class AsyncIODBAPICursor(Protocol): async def nextset(self) -> Optional[bool]: ... + def __aiter__(self) -> AsyncIterator[Any]: ... + class AsyncAdapt_dbapi_cursor: server_side = False @@ -119,7 +122,8 @@ class AsyncAdapt_dbapi_cursor: cursor = self._make_new_cursor(self._connection) self._cursor = self._aenter_cursor(cursor) - self._rows = collections.deque() + if not self.server_side: + self._rows = collections.deque() def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor: try: @@ -258,6 +262,14 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): def fetchall(self) -> Sequence[Any]: return await_(self._cursor.fetchall()) + def __iter__(self) -> Iterator[Any]: + iterator = self._cursor.__aiter__() + while True: + try: + yield await_(iterator.__anext__()) + except StopAsyncIteration: + break + class AsyncAdapt_dbapi_connection(AdaptedConnection): _cursor_cls = AsyncAdapt_dbapi_cursor diff --git a/lib/sqlalchemy/dialects/oracle/oracledb.py b/lib/sqlalchemy/dialects/oracle/oracledb.py index e48dcdc6bb..377310f642 100644 --- a/lib/sqlalchemy/dialects/oracle/oracledb.py +++ b/lib/sqlalchemy/dialects/oracle/oracledb.py @@ -94,10 +94,12 @@ import re from typing import Any from typing import TYPE_CHECKING -from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle +from . import cx_oracle as _cx_oracle from ... import exc from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...engine import default from ...util import await_ if TYPE_CHECKING: @@ -105,8 +107,16 @@ if TYPE_CHECKING: from oracledb import AsyncCursor -class OracleDialect_oracledb(_OracleDialect_cx_oracle): +class OracleExecutionContext_oracledb( + _cx_oracle.OracleExecutionContext_cx_oracle +): + pass + + +class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle): supports_statement_cache = True + execution_ctx_cls = OracleExecutionContext_oracledb + driver = "oracledb" _min_version = (1,) @@ -257,6 +267,17 @@ class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor): return await self._cursor.executemany(operation, seq_of_parameters) +class AsyncAdapt_oracledb_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor +): + __slots__ = () + + def close(self) -> None: + if self._cursor is not None: + self._cursor.close() + self._cursor = None # type: ignore + + class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): _connection: AsyncConnection __slots__ = () @@ -297,6 +318,9 @@ class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): def cursor(self): return AsyncAdapt_oracledb_cursor(self) + def ss_cursor(self): + return AsyncAdapt_oracledb_ss_cursor(self) + def xid(self, *args: Any, **kwargs: Any) -> Any: return self._connection.xid(*args, **kwargs) @@ -331,9 +355,31 @@ class OracledbAdaptDBAPI: ) +class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb): + # restore default create cursor + create_cursor = default.DefaultExecutionContext.create_cursor + + def create_default_cursor(self): + # copy of OracleExecutionContext_cx_oracle.create_cursor + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def create_server_side_cursor(self): + c = self._dbapi_connection.ss_cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + class OracleDialectAsync_oracledb(OracleDialect_oracledb): is_async = True + supports_server_side_cursors = True supports_statement_cache = True + execution_ctx_cls = OracleExecutionContextAsync_oracledb _min_version = (2,) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 66cdeb8463..cb6b75154f 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -520,8 +520,6 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): "_invalidate_schema_cache_asof", ) - server_side = False - _adapt_connection: AsyncAdapt_asyncpg_connection _connection: _AsyncpgConnection _cursor: Optional[_AsyncpgCursor] @@ -636,7 +634,6 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): class AsyncAdapt_asyncpg_ss_cursor( AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncpg_cursor ): - server_side = True __slots__ = ("_rowbuffer",) def __init__(self, adapt_connection): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 5bdae1703a..a1fdce1b46 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -611,15 +611,6 @@ class AsyncAdapt_psycopg_ss_cursor( def _make_new_cursor(self, connection): return connection.cursor(self.name) - # TODO: should this be on the base asyncio adapter? - def __iter__(self): - iterator = self._cursor.__aiter__() - while True: - try: - yield await_(iterator.__anext__()) - except StopAsyncIteration: - break - class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection): _connection: AsyncConnection diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 639a5d056b..7d1565bba3 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -7,6 +7,7 @@ # mypy: ignore-errors import datetime +import re from .. import engines from .. import fixtures @@ -429,6 +430,8 @@ class ServerSideCursorsTest( return getattr(cursor, "server_side", False) elif self.engine.dialect.driver == "psycopg": return bool(getattr(cursor, "name", False)) + elif self.engine.dialect.driver == "oracledb": + return getattr(cursor, "server_side", False) else: return False @@ -449,11 +452,26 @@ class ServerSideCursorsTest( ) return self.engine + def stringify(self, str_): + return re.compile(r"SELECT (\d+)", re.I).sub( + lambda m: str(select(int(m.group(1))).compile(testing.db)), str_ + ) + @testing.combinations( - ("global_string", True, "select 1", True), - ("global_text", True, text("select 1"), True), + ("global_string", True, lambda stringify: stringify("select 1"), True), + ( + "global_text", + True, + lambda stringify: text(stringify("select 1")), + True, + ), ("global_expr", True, select(1), True), - ("global_off_explicit", False, text("select 1"), False), + ( + "global_off_explicit", + False, + lambda stringify: text(stringify("select 1")), + False, + ), ( "stmt_option", False, @@ -471,15 +489,22 @@ class ServerSideCursorsTest( ( "for_update_string", True, - "SELECT 1 FOR UPDATE", + lambda stringify: stringify("SELECT 1 FOR UPDATE"), True, testing.skip_if(["sqlite", "mssql"]), ), - ("text_no_ss", False, text("select 42"), False), + ( + "text_no_ss", + False, + lambda stringify: text(stringify("select 42")), + False, + ), ( "text_ss_option", False, - text("select 42").execution_options(stream_results=True), + lambda stringify: text(stringify("select 42")).execution_options( + stream_results=True + ), True, ), id_="iaaa", @@ -490,6 +515,11 @@ class ServerSideCursorsTest( ): engine = self._fixture(engine_ss_arg) with engine.begin() as conn: + if callable(statement): + statement = testing.resolve_lambda( + statement, stringify=self.stringify + ) + if isinstance(statement, str): result = conn.exec_driver_sql(statement) else: @@ -504,7 +534,7 @@ class ServerSideCursorsTest( # should be enabled for this one result = conn.execution_options( stream_results=True - ).exec_driver_sql("select 1") + ).exec_driver_sql(self.stringify("select 1")) assert self._is_server_side(result.cursor) # the connection has autobegun, which means at the end of the @@ -558,7 +588,9 @@ class ServerSideCursorsTest( test_table = Table( "test_table", md, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), ) @@ -598,7 +630,9 @@ class ServerSideCursorsTest( test_table = Table( "test_table", md, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), ) diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index f6fa21f29d..a4a6f1f47c 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -300,10 +300,6 @@ class PoolTest(fixtures.TestBase): is_(fairy.connection, fairy.dbapi_connection) -def select1(db): - return str(select(1).compile(dialect=db.dialect)) - - class ResetEventTest(fixtures.TestBase): def _fixture(self, **kw): dbapi = Mock() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 31a9c4a70a..148d0be1a2 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -1964,13 +1964,10 @@ class EngineEventsTest(fixtures.TestBase): def test_new_exec_driver_sql_no_events(self): m1 = Mock() - def select1(db): - return str(select(1).compile(dialect=db.dialect)) - with testing.db.connect() as conn: event.listen(conn, "before_execute", m1.before_execute) event.listen(conn, "after_execute", m1.after_execute) - conn.exec_driver_sql(select1(testing.db)) + conn.exec_driver_sql(str(select(1).compile(testing.db))) eq_(m1.mock_calls, []) def test_add_event_after_connect(self, testing_engine): diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index ee5953636d..60edbf608d 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -21,6 +21,7 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import union_all from sqlalchemy.engine import cursor as _cursor from sqlalchemy.ext.asyncio import async_engine_from_config @@ -405,8 +406,7 @@ class AsyncEngineTest(EngineFixture): @async_test async def test_statement_compile(self, async_engine): - stmt = _select1(async_engine) - eq_(str(select(1).compile(async_engine)), stmt) + stmt = str(select(1).compile(async_engine)) async with async_engine.connect() as conn: eq_(str(select(1).compile(conn)), stmt) @@ -967,11 +967,11 @@ class AsyncEventTest(EngineFixture): event.listen(async_engine.sync_engine, "before_cursor_execute", canary) - s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection - await conn.execute(text(s1)) + await conn.execute(select(1)) + s1 = str(select(1).compile(async_engine)) eq_( canary.mock_calls, [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)], @@ -981,15 +981,15 @@ class AsyncEventTest(EngineFixture): async def test_sync_before_cursor_execute_connection(self, async_engine): canary = mock.Mock() - s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection event.listen( async_engine.sync_engine, "before_cursor_execute", canary ) - await conn.execute(text(s1)) + await conn.execute(select(1)) + s1 = str(select(1).compile(async_engine)) eq_( canary.mock_calls, [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)], @@ -1331,20 +1331,51 @@ class AsyncResultTest(EngineFixture): ): await result.one() - @testing.combinations( - ("scalars",), ("stream_scalars",), argnames="filter_" - ) + @testing.combinations(("scalars",), ("stream_scalars",), argnames="case") @async_test - async def test_scalars(self, async_engine, filter_): + async def test_scalars(self, async_engine, case): users = self.tables.users async with async_engine.connect() as conn: - if filter_ == "scalars": + if case == "scalars": result = (await conn.scalars(select(users))).all() - elif filter_ == "stream_scalars": + elif case == "stream_scalars": result = await (await conn.stream_scalars(select(users))).all() eq_(result, list(range(1, 20))) + @async_test + @testing.combinations(("stream",), ("stream_scalars",), argnames="case") + async def test_stream_fetch_many_not_complete(self, async_engine, case): + users = self.tables.users + big_query = select(users).join(users.alias("other"), true()) + async with async_engine.connect() as conn: + if case == "stream": + result = await conn.stream(big_query) + elif case == "stream_scalars": + result = await conn.stream_scalars(big_query) + + f1 = await result.fetchmany(5) + f2 = await result.fetchmany(10) + f3 = await result.fetchmany(7) + eq_(len(f1) + len(f2) + len(f3), 22) + + res = await result.fetchall() + eq_(len(res), 19 * 19 - 22) + + @async_test + @testing.combinations(("stream",), ("execute",), argnames="case") + async def test_cursor_close(self, async_engine, case): + users = self.tables.users + async with async_engine.connect() as conn: + if case == "stream": + result = await conn.stream(select(users)) + cursor = result._real_result.cursor + elif case == "execute": + result = await conn.execute(select(users)) + cursor = result.cursor + + await conn.run_sync(lambda _: cursor.close()) + class TextSyncDBAPI(fixtures.TestBase): __requires__ = ("asyncio",) @@ -1516,17 +1547,10 @@ class PoolRegenTest(EngineFixture): async def thing(engine): async with engine.connect() as conn: - await conn.exec_driver_sql("select 1") + await conn.exec_driver_sql(str(select(1).compile(engine))) if do_dispose: await engine.dispose() tasks = [thing(engine) for _ in range(10)] await asyncio.gather(*tasks) - - -def _select1(engine): - if engine.dialect.name == "oracle": - return "SELECT 1 FROM DUAL" - else: - return "SELECT 1"