From: Federico Caselli Date: Thu, 1 Aug 2024 19:16:20 +0000 (+0200) Subject: Added support for server-side cursor in oracledb async dialect. X-Git-Tag: rel_2_0_32~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=de7727d25cf980e9a215ec73603b9fd469b7d357;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added support for server-side cursor in oracledb async dialect. Added API support for server-side cursors for the oracledb async dialect, allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar stream methods. Fixes: #10820 Change-Id: I861670ccc20a81ec5ee45132b8059fc2a0359087 (cherry picked from commit ffb2e2d033f8e227b80ba3c5d06c67a96310e1ec) --- diff --git a/doc/build/changelog/unreleased_20/10820.rst b/doc/build/changelog/unreleased_20/10820.rst new file mode 100644 index 0000000000..e2cc717e2e --- /dev/null +++ b/doc/build/changelog/unreleased_20/10820.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: oracle, usecase + :tickets: 10820 + + Added API support for server-side cursors for the oracledb async dialect, + allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar + stream methods. diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 8dc198cf8e..9b19bef78f 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -36,7 +36,8 @@ class AsyncAdapt_dbapi_cursor: cursor = self._connection.cursor() self._cursor = self._aenter_cursor(cursor) - self._rows = collections.deque() + if not self.server_side: + self._rows = collections.deque() def _aenter_cursor(self, cursor): return self.await_(cursor.__aenter__()) @@ -149,6 +150,14 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): def fetchall(self): return self.await_(self._cursor.fetchall()) + def __iter__(self): + iterator = self._cursor.__aiter__() + while True: + try: + yield self.await_(iterator.__anext__()) + except StopAsyncIteration: + break + class AsyncAdapt_dbapi_connection(AdaptedConnection): _cursor_cls = AsyncAdapt_dbapi_cursor diff --git a/lib/sqlalchemy/dialects/oracle/oracledb.py b/lib/sqlalchemy/dialects/oracle/oracledb.py index 1f5a19b876..0667ed768e 100644 --- a/lib/sqlalchemy/dialects/oracle/oracledb.py +++ b/lib/sqlalchemy/dialects/oracle/oracledb.py @@ -94,12 +94,14 @@ import re from typing import Any from typing import TYPE_CHECKING -from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle +from . import cx_oracle as _cx_oracle from ... import exc from ... import pool from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection +from ...engine import default from ...util import asbool from ...util import await_fallback from ...util import await_only @@ -109,8 +111,16 @@ if TYPE_CHECKING: from oracledb import AsyncCursor -class OracleDialect_oracledb(_OracleDialect_cx_oracle): +class OracleExecutionContext_oracledb( + _cx_oracle.OracleExecutionContext_cx_oracle +): + pass + + +class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle): supports_statement_cache = True + execution_ctx_cls = OracleExecutionContext_oracledb + driver = "oracledb" _min_version = (1,) @@ -267,6 +277,17 @@ class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor): self.close() +class AsyncAdapt_oracledb_ss_cursor( + AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor +): + __slots__ = () + + def close(self) -> None: + if self._cursor is not None: + self._cursor.close() + self._cursor = None # type: ignore + + class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): _connection: AsyncConnection __slots__ = () @@ -307,6 +328,9 @@ class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): def cursor(self): return AsyncAdapt_oracledb_cursor(self) + def ss_cursor(self): + return AsyncAdapt_oracledb_ss_cursor(self) + def xid(self, *args: Any, **kwargs: Any) -> Any: return self._connection.xid(*args, **kwargs) @@ -355,9 +379,31 @@ class OracledbAdaptDBAPI: ) +class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb): + # restore default create cursor + create_cursor = default.DefaultExecutionContext.create_cursor + + def create_default_cursor(self): + # copy of OracleExecutionContext_cx_oracle.create_cursor + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def create_server_side_cursor(self): + c = self._dbapi_connection.ss_cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + class OracleDialectAsync_oracledb(OracleDialect_oracledb): is_async = True + supports_server_side_cursors = True supports_statement_cache = True + execution_ctx_cls = OracleExecutionContextAsync_oracledb _min_version = (2,) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index b3f432fb76..2b91a559db 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -7,6 +7,7 @@ # mypy: ignore-errors import datetime +import re from .. import engines from .. import fixtures @@ -273,6 +274,8 @@ class ServerSideCursorsTest( return getattr(cursor, "server_side", False) elif self.engine.dialect.driver == "psycopg": return bool(getattr(cursor, "name", False)) + elif self.engine.dialect.driver == "oracledb": + return getattr(cursor, "server_side", False) else: return False @@ -293,11 +296,26 @@ class ServerSideCursorsTest( ) return self.engine + def stringify(self, str_): + return re.compile(r"SELECT (\d+)", re.I).sub( + lambda m: str(select(int(m.group(1))).compile(testing.db)), str_ + ) + @testing.combinations( - ("global_string", True, "select 1", True), - ("global_text", True, text("select 1"), True), + ("global_string", True, lambda stringify: stringify("select 1"), True), + ( + "global_text", + True, + lambda stringify: text(stringify("select 1")), + True, + ), ("global_expr", True, select(1), True), - ("global_off_explicit", False, text("select 1"), False), + ( + "global_off_explicit", + False, + lambda stringify: text(stringify("select 1")), + False, + ), ( "stmt_option", False, @@ -315,15 +333,22 @@ class ServerSideCursorsTest( ( "for_update_string", True, - "SELECT 1 FOR UPDATE", + lambda stringify: stringify("SELECT 1 FOR UPDATE"), True, testing.skip_if(["sqlite", "mssql"]), ), - ("text_no_ss", False, text("select 42"), False), + ( + "text_no_ss", + False, + lambda stringify: text(stringify("select 42")), + False, + ), ( "text_ss_option", False, - text("select 42").execution_options(stream_results=True), + lambda stringify: text(stringify("select 42")).execution_options( + stream_results=True + ), True, ), id_="iaaa", @@ -334,6 +359,11 @@ class ServerSideCursorsTest( ): engine = self._fixture(engine_ss_arg) with engine.begin() as conn: + if callable(statement): + statement = testing.resolve_lambda( + statement, stringify=self.stringify + ) + if isinstance(statement, str): result = conn.exec_driver_sql(statement) else: @@ -348,7 +378,7 @@ class ServerSideCursorsTest( # should be enabled for this one result = conn.execution_options( stream_results=True - ).exec_driver_sql("select 1") + ).exec_driver_sql(self.stringify("select 1")) assert self._is_server_side(result.cursor) # the connection has autobegun, which means at the end of the @@ -402,7 +432,9 @@ class ServerSideCursorsTest( test_table = Table( "test_table", md, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), ) @@ -442,7 +474,9 @@ class ServerSideCursorsTest( test_table = Table( "test_table", md, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("data", String(50)), ) diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 9041a6af10..30bf9e66f6 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -300,10 +300,6 @@ class PoolTest(fixtures.TestBase): is_(fairy.connection, fairy.dbapi_connection) -def select1(db): - return str(select(1).compile(dialect=db.dialect)) - - class ResetEventTest(fixtures.TestBase): def _fixture(self, **kw): dbapi = Mock() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 122c08461d..61c422bb56 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -1940,13 +1940,10 @@ class EngineEventsTest(fixtures.TestBase): def test_new_exec_driver_sql_no_events(self): m1 = Mock() - def select1(db): - return str(select(1).compile(dialect=db.dialect)) - with testing.db.connect() as conn: event.listen(conn, "before_execute", m1.before_execute) event.listen(conn, "after_execute", m1.after_execute) - conn.exec_driver_sql(select1(testing.db)) + conn.exec_driver_sql(str(select(1).compile(testing.db))) eq_(m1.mock_calls, []) def test_add_event_after_connect(self, testing_engine): diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 9fb12e6936..227307e086 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -21,6 +21,7 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import union_all from sqlalchemy.engine import cursor as _cursor from sqlalchemy.ext.asyncio import async_engine_from_config @@ -405,8 +406,7 @@ class AsyncEngineTest(EngineFixture): @async_test async def test_statement_compile(self, async_engine): - stmt = _select1(async_engine) - eq_(str(select(1).compile(async_engine)), stmt) + stmt = str(select(1).compile(async_engine)) async with async_engine.connect() as conn: eq_(str(select(1).compile(conn)), stmt) @@ -967,11 +967,11 @@ class AsyncEventTest(EngineFixture): event.listen(async_engine.sync_engine, "before_cursor_execute", canary) - s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection - await conn.execute(text(s1)) + await conn.execute(select(1)) + s1 = str(select(1).compile(async_engine)) eq_( canary.mock_calls, [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)], @@ -981,15 +981,15 @@ class AsyncEventTest(EngineFixture): async def test_sync_before_cursor_execute_connection(self, async_engine): canary = mock.Mock() - s1 = _select1(async_engine) async with async_engine.connect() as conn: sync_conn = conn.sync_connection event.listen( async_engine.sync_engine, "before_cursor_execute", canary ) - await conn.execute(text(s1)) + await conn.execute(select(1)) + s1 = str(select(1).compile(async_engine)) eq_( canary.mock_calls, [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)], @@ -1331,20 +1331,51 @@ class AsyncResultTest(EngineFixture): ): await result.one() - @testing.combinations( - ("scalars",), ("stream_scalars",), argnames="filter_" - ) + @testing.combinations(("scalars",), ("stream_scalars",), argnames="case") @async_test - async def test_scalars(self, async_engine, filter_): + async def test_scalars(self, async_engine, case): users = self.tables.users async with async_engine.connect() as conn: - if filter_ == "scalars": + if case == "scalars": result = (await conn.scalars(select(users))).all() - elif filter_ == "stream_scalars": + elif case == "stream_scalars": result = await (await conn.stream_scalars(select(users))).all() eq_(result, list(range(1, 20))) + @async_test + @testing.combinations(("stream",), ("stream_scalars",), argnames="case") + async def test_stream_fetch_many_not_complete(self, async_engine, case): + users = self.tables.users + big_query = select(users).join(users.alias("other"), true()) + async with async_engine.connect() as conn: + if case == "stream": + result = await conn.stream(big_query) + elif case == "stream_scalars": + result = await conn.stream_scalars(big_query) + + f1 = await result.fetchmany(5) + f2 = await result.fetchmany(10) + f3 = await result.fetchmany(7) + eq_(len(f1) + len(f2) + len(f3), 22) + + res = await result.fetchall() + eq_(len(res), 19 * 19 - 22) + + @async_test + @testing.combinations(("stream",), ("execute",), argnames="case") + async def test_cursor_close(self, async_engine, case): + users = self.tables.users + async with async_engine.connect() as conn: + if case == "stream": + result = await conn.stream(select(users)) + cursor = result._real_result.cursor + elif case == "execute": + result = await conn.execute(select(users)) + cursor = result.cursor + + await conn.run_sync(lambda _: cursor.close()) + class TextSyncDBAPI(fixtures.TestBase): __requires__ = ("asyncio",) @@ -1516,17 +1547,10 @@ class PoolRegenTest(EngineFixture): async def thing(engine): async with engine.connect() as conn: - await conn.exec_driver_sql("select 1") + await conn.exec_driver_sql(str(select(1).compile(engine))) if do_dispose: await engine.dispose() tasks = [thing(engine) for _ in range(10)] await asyncio.gather(*tasks) - - -def _select1(engine): - if engine.dialect.name == "oracle": - return "SELECT 1 FROM DUAL" - else: - return "SELECT 1"