]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added support for server-side cursor in oracledb async dialect.
authorFederico Caselli <cfederico87@gmail.com>
Thu, 1 Aug 2024 19:16:20 +0000 (21:16 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Aug 2024 16:16:47 +0000 (12:16 -0400)
Added API support for server-side cursors for the oracledb async dialect,
allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar
stream methods.

Fixes: #10820
Change-Id: I861670ccc20a81ec5ee45132b8059fc2a0359087

doc/build/changelog/unreleased_20/10820.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/testing/suite/test_results.py
test/engine/test_deprecations.py
test/engine/test_execute.py
test/ext/asyncio/test_engine_py3k.py

diff --git a/doc/build/changelog/unreleased_20/10820.rst b/doc/build/changelog/unreleased_20/10820.rst
new file mode 100644 (file)
index 0000000..e2cc717
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: oracle, usecase
+    :tickets: 10820
+
+    Added API support for server-side cursors for the oracledb async dialect,
+    allowing use of the :meth:`_asyncio.AsyncConnection.stream` and similar
+    stream methods.
index 34820facb6a2a3f21a9883e79dcc95fc07eacb15..27d438cda2741e1fe42357ba1fbc3ba3a8b85d2f 100644 (file)
@@ -13,6 +13,7 @@ import asyncio
 import collections
 import sys
 from typing import Any
+from typing import AsyncIterator
 from typing import Deque
 from typing import Iterator
 from typing import NoReturn
@@ -97,6 +98,8 @@ class AsyncIODBAPICursor(Protocol):
 
     async def nextset(self) -> Optional[bool]: ...
 
+    def __aiter__(self) -> AsyncIterator[Any]: ...
+
 
 class AsyncAdapt_dbapi_cursor:
     server_side = False
@@ -119,7 +122,8 @@ class AsyncAdapt_dbapi_cursor:
         cursor = self._make_new_cursor(self._connection)
         self._cursor = self._aenter_cursor(cursor)
 
-        self._rows = collections.deque()
+        if not self.server_side:
+            self._rows = collections.deque()
 
     def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
         try:
@@ -258,6 +262,14 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
     def fetchall(self) -> Sequence[Any]:
         return await_(self._cursor.fetchall())
 
+    def __iter__(self) -> Iterator[Any]:
+        iterator = self._cursor.__aiter__()
+        while True:
+            try:
+                yield await_(iterator.__anext__())
+            except StopAsyncIteration:
+                break
+
 
 class AsyncAdapt_dbapi_connection(AdaptedConnection):
     _cursor_cls = AsyncAdapt_dbapi_cursor
index e48dcdc6bbe0a54d33f54841bd99f9e74842b061..377310f642533be8c9008f87509a110fdb54ff79 100644 (file)
@@ -94,10 +94,12 @@ import re
 from typing import Any
 from typing import TYPE_CHECKING
 
-from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
+from . import cx_oracle as _cx_oracle
 from ... import exc
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...engine import default
 from ...util import await_
 
 if TYPE_CHECKING:
@@ -105,8 +107,16 @@ if TYPE_CHECKING:
     from oracledb import AsyncCursor
 
 
-class OracleDialect_oracledb(_OracleDialect_cx_oracle):
+class OracleExecutionContext_oracledb(
+    _cx_oracle.OracleExecutionContext_cx_oracle
+):
+    pass
+
+
+class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
     supports_statement_cache = True
+    execution_ctx_cls = OracleExecutionContext_oracledb
+
     driver = "oracledb"
     _min_version = (1,)
 
@@ -257,6 +267,17 @@ class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
         return await self._cursor.executemany(operation, seq_of_parameters)
 
 
+class AsyncAdapt_oracledb_ss_cursor(
+    AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor
+):
+    __slots__ = ()
+
+    def close(self) -> None:
+        if self._cursor is not None:
+            self._cursor.close()
+            self._cursor = None  # type: ignore
+
+
 class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
     _connection: AsyncConnection
     __slots__ = ()
@@ -297,6 +318,9 @@ class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
     def cursor(self):
         return AsyncAdapt_oracledb_cursor(self)
 
+    def ss_cursor(self):
+        return AsyncAdapt_oracledb_ss_cursor(self)
+
     def xid(self, *args: Any, **kwargs: Any) -> Any:
         return self._connection.xid(*args, **kwargs)
 
@@ -331,9 +355,31 @@ class OracledbAdaptDBAPI:
         )
 
 
+class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb):
+    # restore default create cursor
+    create_cursor = default.DefaultExecutionContext.create_cursor
+
+    def create_default_cursor(self):
+        # copy of OracleExecutionContext_cx_oracle.create_cursor
+        c = self._dbapi_connection.cursor()
+        if self.dialect.arraysize:
+            c.arraysize = self.dialect.arraysize
+
+        return c
+
+    def create_server_side_cursor(self):
+        c = self._dbapi_connection.ss_cursor()
+        if self.dialect.arraysize:
+            c.arraysize = self.dialect.arraysize
+
+        return c
+
+
 class OracleDialectAsync_oracledb(OracleDialect_oracledb):
     is_async = True
+    supports_server_side_cursors = True
     supports_statement_cache = True
+    execution_ctx_cls = OracleExecutionContextAsync_oracledb
 
     _min_version = (2,)
 
index 66cdeb84639ea044675582ce6c4652a0a851457d..cb6b75154f31b0fb3189e677324b1e0b0fd0f3f3 100644 (file)
@@ -520,8 +520,6 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
         "_invalidate_schema_cache_asof",
     )
 
-    server_side = False
-
     _adapt_connection: AsyncAdapt_asyncpg_connection
     _connection: _AsyncpgConnection
     _cursor: Optional[_AsyncpgCursor]
@@ -636,7 +634,6 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
 class AsyncAdapt_asyncpg_ss_cursor(
     AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncpg_cursor
 ):
-    server_side = True
     __slots__ = ("_rowbuffer",)
 
     def __init__(self, adapt_connection):
index 5bdae1703a8ae674d63bec5682f99393c17896d7..a1fdce1b4635a08245aae478f6c90e6a1a98c798 100644 (file)
@@ -611,15 +611,6 @@ class AsyncAdapt_psycopg_ss_cursor(
     def _make_new_cursor(self, connection):
         return connection.cursor(self.name)
 
-    # TODO: should this be on the base asyncio adapter?
-    def __iter__(self):
-        iterator = self._cursor.__aiter__()
-        while True:
-            try:
-                yield await_(iterator.__anext__())
-            except StopAsyncIteration:
-                break
-
 
 class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
     _connection: AsyncConnection
index 639a5d056b7ce7845b3d32849817491dd919973a..7d1565bba3d494e9391f9bcba5878c4cad81df79 100644 (file)
@@ -7,6 +7,7 @@
 # mypy: ignore-errors
 
 import datetime
+import re
 
 from .. import engines
 from .. import fixtures
@@ -429,6 +430,8 @@ class ServerSideCursorsTest(
             return getattr(cursor, "server_side", False)
         elif self.engine.dialect.driver == "psycopg":
             return bool(getattr(cursor, "name", False))
+        elif self.engine.dialect.driver == "oracledb":
+            return getattr(cursor, "server_side", False)
         else:
             return False
 
@@ -449,11 +452,26 @@ class ServerSideCursorsTest(
             )
         return self.engine
 
+    def stringify(self, str_):
+        return re.compile(r"SELECT (\d+)", re.I).sub(
+            lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
+        )
+
     @testing.combinations(
-        ("global_string", True, "select 1", True),
-        ("global_text", True, text("select 1"), True),
+        ("global_string", True, lambda stringify: stringify("select 1"), True),
+        (
+            "global_text",
+            True,
+            lambda stringify: text(stringify("select 1")),
+            True,
+        ),
         ("global_expr", True, select(1), True),
-        ("global_off_explicit", False, text("select 1"), False),
+        (
+            "global_off_explicit",
+            False,
+            lambda stringify: text(stringify("select 1")),
+            False,
+        ),
         (
             "stmt_option",
             False,
@@ -471,15 +489,22 @@ class ServerSideCursorsTest(
         (
             "for_update_string",
             True,
-            "SELECT 1 FOR UPDATE",
+            lambda stringify: stringify("SELECT 1 FOR UPDATE"),
             True,
             testing.skip_if(["sqlite", "mssql"]),
         ),
-        ("text_no_ss", False, text("select 42"), False),
+        (
+            "text_no_ss",
+            False,
+            lambda stringify: text(stringify("select 42")),
+            False,
+        ),
         (
             "text_ss_option",
             False,
-            text("select 42").execution_options(stream_results=True),
+            lambda stringify: text(stringify("select 42")).execution_options(
+                stream_results=True
+            ),
             True,
         ),
         id_="iaaa",
@@ -490,6 +515,11 @@ class ServerSideCursorsTest(
     ):
         engine = self._fixture(engine_ss_arg)
         with engine.begin() as conn:
+            if callable(statement):
+                statement = testing.resolve_lambda(
+                    statement, stringify=self.stringify
+                )
+
             if isinstance(statement, str):
                 result = conn.exec_driver_sql(statement)
             else:
@@ -504,7 +534,7 @@ class ServerSideCursorsTest(
             # should be enabled for this one
             result = conn.execution_options(
                 stream_results=True
-            ).exec_driver_sql("select 1")
+            ).exec_driver_sql(self.stringify("select 1"))
             assert self._is_server_side(result.cursor)
 
             # the connection has autobegun, which means at the end of the
@@ -558,7 +588,9 @@ class ServerSideCursorsTest(
         test_table = Table(
             "test_table",
             md,
-            Column("id", Integer, primary_key=True),
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
             Column("data", String(50)),
         )
 
@@ -598,7 +630,9 @@ class ServerSideCursorsTest(
         test_table = Table(
             "test_table",
             md,
-            Column("id", Integer, primary_key=True),
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
             Column("data", String(50)),
         )
 
index f6fa21f29dd418a25863aa800491299caf151bc5..a4a6f1f47cd55da5e7f61323046b95f550b16a65 100644 (file)
@@ -300,10 +300,6 @@ class PoolTest(fixtures.TestBase):
             is_(fairy.connection, fairy.dbapi_connection)
 
 
-def select1(db):
-    return str(select(1).compile(dialect=db.dialect))
-
-
 class ResetEventTest(fixtures.TestBase):
     def _fixture(self, **kw):
         dbapi = Mock()
index 31a9c4a70a5a1181f74d761521940f8df2498ab4..148d0be1a28527a7017da5bd86c95878f91adedc 100644 (file)
@@ -1964,13 +1964,10 @@ class EngineEventsTest(fixtures.TestBase):
     def test_new_exec_driver_sql_no_events(self):
         m1 = Mock()
 
-        def select1(db):
-            return str(select(1).compile(dialect=db.dialect))
-
         with testing.db.connect() as conn:
             event.listen(conn, "before_execute", m1.before_execute)
             event.listen(conn, "after_execute", m1.after_execute)
-            conn.exec_driver_sql(select1(testing.db))
+            conn.exec_driver_sql(str(select(1).compile(testing.db)))
         eq_(m1.mock_calls, [])
 
     def test_add_event_after_connect(self, testing_engine):
index ee5953636d447a56d7107ca6747825612965fe57..60edbf608d924382acb69a0646bb814159f802fd 100644 (file)
@@ -21,6 +21,7 @@ from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import union_all
 from sqlalchemy.engine import cursor as _cursor
 from sqlalchemy.ext.asyncio import async_engine_from_config
@@ -405,8 +406,7 @@ class AsyncEngineTest(EngineFixture):
 
     @async_test
     async def test_statement_compile(self, async_engine):
-        stmt = _select1(async_engine)
-        eq_(str(select(1).compile(async_engine)), stmt)
+        stmt = str(select(1).compile(async_engine))
         async with async_engine.connect() as conn:
             eq_(str(select(1).compile(conn)), stmt)
 
@@ -967,11 +967,11 @@ class AsyncEventTest(EngineFixture):
 
         event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
 
-        s1 = _select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
-            await conn.execute(text(s1))
+            await conn.execute(select(1))
 
+        s1 = str(select(1).compile(async_engine))
         eq_(
             canary.mock_calls,
             [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
@@ -981,15 +981,15 @@ class AsyncEventTest(EngineFixture):
     async def test_sync_before_cursor_execute_connection(self, async_engine):
         canary = mock.Mock()
 
-        s1 = _select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
 
             event.listen(
                 async_engine.sync_engine, "before_cursor_execute", canary
             )
-            await conn.execute(text(s1))
+            await conn.execute(select(1))
 
+        s1 = str(select(1).compile(async_engine))
         eq_(
             canary.mock_calls,
             [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
@@ -1331,20 +1331,51 @@ class AsyncResultTest(EngineFixture):
             ):
                 await result.one()
 
-    @testing.combinations(
-        ("scalars",), ("stream_scalars",), argnames="filter_"
-    )
+    @testing.combinations(("scalars",), ("stream_scalars",), argnames="case")
     @async_test
-    async def test_scalars(self, async_engine, filter_):
+    async def test_scalars(self, async_engine, case):
         users = self.tables.users
         async with async_engine.connect() as conn:
-            if filter_ == "scalars":
+            if case == "scalars":
                 result = (await conn.scalars(select(users))).all()
-            elif filter_ == "stream_scalars":
+            elif case == "stream_scalars":
                 result = await (await conn.stream_scalars(select(users))).all()
 
         eq_(result, list(range(1, 20)))
 
+    @async_test
+    @testing.combinations(("stream",), ("stream_scalars",), argnames="case")
+    async def test_stream_fetch_many_not_complete(self, async_engine, case):
+        users = self.tables.users
+        big_query = select(users).join(users.alias("other"), true())
+        async with async_engine.connect() as conn:
+            if case == "stream":
+                result = await conn.stream(big_query)
+            elif case == "stream_scalars":
+                result = await conn.stream_scalars(big_query)
+
+            f1 = await result.fetchmany(5)
+            f2 = await result.fetchmany(10)
+            f3 = await result.fetchmany(7)
+            eq_(len(f1) + len(f2) + len(f3), 22)
+
+            res = await result.fetchall()
+            eq_(len(res), 19 * 19 - 22)
+
+    @async_test
+    @testing.combinations(("stream",), ("execute",), argnames="case")
+    async def test_cursor_close(self, async_engine, case):
+        users = self.tables.users
+        async with async_engine.connect() as conn:
+            if case == "stream":
+                result = await conn.stream(select(users))
+                cursor = result._real_result.cursor
+            elif case == "execute":
+                result = await conn.execute(select(users))
+                cursor = result.cursor
+
+            await conn.run_sync(lambda _: cursor.close())
+
 
 class TextSyncDBAPI(fixtures.TestBase):
     __requires__ = ("asyncio",)
@@ -1516,17 +1547,10 @@ class PoolRegenTest(EngineFixture):
 
         async def thing(engine):
             async with engine.connect() as conn:
-                await conn.exec_driver_sql("select 1")
+                await conn.exec_driver_sql(str(select(1).compile(engine)))
 
         if do_dispose:
             await engine.dispose()
 
         tasks = [thing(engine) for _ in range(10)]
         await asyncio.gather(*tasks)
-
-
-def _select1(engine):
-    if engine.dialect.name == "oracle":
-        return "SELECT 1 FROM DUAL"
-    else:
-        return "SELECT 1"