]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add oracledb_async driver support
authorFederico Caselli <cfederico87@gmail.com>
Thu, 21 Dec 2023 22:41:56 +0000 (23:41 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 2 Jan 2024 18:17:33 +0000 (19:17 +0100)
Added support for :ref:`oracledb` in async mode.
The current implementation has some limitation, preventing
the support for :meth:`_asyncio.AsyncConnection.stream`.
Improved support if planned for the 2.1 release of SQLAlchemy.

Fixes: #10679
Change-Id: Iff123cf6241bcfa0fbac57529b80f933951be0a7

16 files changed:
doc/build/changelog/unreleased_20/10679.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/aioodbc.py
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/testing/provision.py
setup.cfg
test/dialect/oracle/test_dialect.py
test/dialect/oracle/test_types.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/sql/test_operators.py
tox.ini

diff --git a/doc/build/changelog/unreleased_20/10679.rst b/doc/build/changelog/unreleased_20/10679.rst
new file mode 100644 (file)
index 0000000..485a87e
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: oracle, asyncio
+    :tickets: 10679
+
+    Added support for :ref:`oracledb` in async mode.
+    The current implementation has some limitation, preventing
+    the support for :meth:`_asyncio.AsyncConnection.stream`.
+    Improved support if planned for the 2.1 release of SQLAlchemy.
index 14b660a69c2782e6fee6e139fb6eeb9add32ef1d..2423bc5ec80a881ef046ff1532f5e1e45a3db881 100644 (file)
@@ -153,18 +153,5 @@ class aiodbcConnector(PyODBCConnector):
 
         return (), kw
 
-    def _do_isolation_level(self, connection, autocommit, isolation_level):
-        connection.set_autocommit(autocommit)
-        connection.set_isolation_level(isolation_level)
-
-    def _do_autocommit(self, connection, value):
-        connection.set_autocommit(value)
-
-    def set_readonly(self, connection, value):
-        connection.set_read_only(value)
-
-    def set_deferrable(self, connection, value):
-        connection.set_deferrable(value)
-
     def get_driver_connection(self, connection):
         return connection._connection
index 5f6d8b72a9b091f9eaf70a469da3a3a0c54c0aec..5126a466080a231d1013fe3b93cd2936c8bbb27d 100644 (file)
@@ -134,14 +134,16 @@ class AsyncAdapt_dbapi_cursor:
         self._connection = adapt_connection._connection
 
         cursor = self._make_new_cursor(self._connection)
+        self._cursor = self._aenter_cursor(cursor)
 
+        self._rows = collections.deque()
+
+    def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
         try:
-            self._cursor = await_(cursor.__aenter__())
+            return await_(cursor.__aenter__())  # type: ignore[no-any-return]
         except Exception as error:
             self._adapt_connection._handle_exception(error)
 
-        self._rows = collections.deque()
-
     def _make_new_cursor(
         self, connection: AsyncIODBAPIConnection
     ) -> AsyncIODBAPICursor:
@@ -204,10 +206,6 @@ class AsyncAdapt_dbapi_cursor:
                 result = await self._cursor.execute(operation, parameters)
 
             if self._cursor.description and not self.server_side:
-                # aioodbc has a "fake" async result, so we have to pull it out
-                # of that here since our default result is not async.
-                # we could just as easily grab "_rows" here and be done with it
-                # but this is safer.
                 self._rows = collections.deque(await self._cursor.fetchall())
             return result
 
index e2c8d327a06cc472e0e1edbca2e05a0fb4394255..d855122ee0c7920a8683863d59244c501f02886b 100644 (file)
@@ -5,7 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
-
+from types import ModuleType
 
 from . import base  # noqa
 from . import cx_oracle  # noqa
@@ -33,6 +33,10 @@ from .base import TIMESTAMP
 from .base import VARCHAR
 from .base import VARCHAR2
 
+# Alias oracledb also as oracledb_async
+oracledb_async = type(
+    "oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async}
+)
 
 base.dialect = dialect = cx_oracle.dialect
 
index 440ccad2bc17d94fdb2092927f719ca31cc8a74b..e8ed3ab5cb2c23da340d60fac1a1c8d7d087ea3d 100644 (file)
@@ -815,6 +815,8 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
 
                             out_parameters[name] = self.cursor.var(
                                 dbtype,
+                                # this is fine also in oracledb_async since
+                                # the driver will await the read coroutine
                                 outconverter=lambda value: value.read(),
                                 arraysize=len_params,
                             )
index 4c6e62446c06415ef7a74f78cbbb5305836425bc..78deecf4a24b5bb8d8332855b3d3c510ed702c80 100644 (file)
@@ -23,6 +23,31 @@ the Oracle Client Interface in the same way as cx_Oracle.
     :ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver
     as well.
 
+The SQLAlchemy ``oracledb`` dialect provides both a sync and an async
+implementation under the same dialect name. The proper version is
+selected depending on how the engine is created:
+
+* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will
+  automatically select the sync version, e.g.::
+
+    from sqlalchemy import create_engine
+    sync_engine = create_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
+
+* calling :func:`_asyncio.create_async_engine` with
+  ``oracle+oracledb://...`` will automatically select the async version,
+  e.g.::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    asyncio_engine = create_async_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
+
+The asyncio version of the dialect may also be specified explicitly using the
+``oracledb_async`` suffix, as::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    asyncio_engine = create_async_engine("oracle+oracledb_async://scott:tiger@localhost/?service_name=XEPDB1")
+
+.. versionadded:: 2.0.25 added support for the async version of oracledb.
+
 Thick mode support
 ------------------
 
@@ -49,15 +74,28 @@ like the ``lib_dir`` path, a dict may be passed to this parameter, as in::
 .. versionadded:: 2.0.0 added support for oracledb driver.
 
 """  # noqa
+from __future__ import annotations
+
+import collections
 import re
+from typing import Any
+from typing import TYPE_CHECKING
 
 from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
 from ... import exc
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...util import await_
+
+if TYPE_CHECKING:
+    from oracledb import AsyncConnection
+    from oracledb import AsyncCursor
 
 
 class OracleDialect_oracledb(_OracleDialect_cx_oracle):
     supports_statement_cache = True
     driver = "oracledb"
+    _min_version = (1,)
 
     def __init__(
         self,
@@ -92,6 +130,10 @@ class OracleDialect_oracledb(_OracleDialect_cx_oracle):
     def is_thin_mode(cls, connection):
         return connection.connection.dbapi_connection.thin
 
+    @classmethod
+    def get_async_dialect_cls(cls, url):
+        return OracleDialectAsync_oracledb
+
     def _load_version(self, dbapi_module):
         version = (0, 0, 0)
         if dbapi_module is not None:
@@ -101,10 +143,136 @@ class OracleDialect_oracledb(_OracleDialect_cx_oracle):
                     int(x) for x in m.group(1, 2, 3) if x is not None
                 )
         self.oracledb_ver = version
-        if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0):
+        if (
+            self.oracledb_ver > (0, 0, 0)
+            and self.oracledb_ver < self._min_version
+        ):
             raise exc.InvalidRequestError(
-                "oracledb version 1 and above are supported"
+                f"oracledb version {self._min_version} and above are supported"
             )
 
 
+class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
+    _cursor: AsyncCursor
+    __slots__ = ()
+
+    @property
+    def outputtypehandler(self):
+        return self._cursor.outputtypehandler
+
+    @outputtypehandler.setter
+    def outputtypehandler(self, value):
+        self._cursor.outputtypehandler = value
+
+    def var(self, *args, **kwargs):
+        return self._cursor.var(*args, **kwargs)
+
+    def close(self):
+        self._rows.clear()
+        self._cursor.close()
+
+    def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
+        return self._cursor.setinputsizes(*args, **kwargs)
+
+    def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor:
+        try:
+            return cursor.__enter__()
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
+
+    async def _execute_async(self, operation, parameters):
+        # override to not use mutex, oracledb already has mutex
+
+        if parameters is None:
+            result = await self._cursor.execute(operation)
+        else:
+            result = await self._cursor.execute(operation, parameters)
+
+        if self._cursor.description and not self.server_side:
+            self._rows = collections.deque(await self._cursor.fetchall())
+        return result
+
+    async def _executemany_async(
+        self,
+        operation,
+        seq_of_parameters,
+    ):
+        # override to not use mutex, oracledb already has mutex
+        return await self._cursor.executemany(operation, seq_of_parameters)
+
+
+class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
+    _connection: AsyncConnection
+    __slots__ = ()
+
+    thin = True
+
+    _cursor_cls = AsyncAdapt_oracledb_cursor
+    _ss_cursor_cls = None
+
+    @property
+    def autocommit(self):
+        return self._connection.autocommit
+
+    @autocommit.setter
+    def autocommit(self, value):
+        self._connection.autocommit = value
+
+    @property
+    def outputtypehandler(self):
+        return self._connection.outputtypehandler
+
+    @outputtypehandler.setter
+    def outputtypehandler(self, value):
+        self._connection.outputtypehandler = value
+
+    @property
+    def version(self):
+        return self._connection.version
+
+    @property
+    def stmtcachesize(self):
+        return self._connection.stmtcachesize
+
+    @stmtcachesize.setter
+    def stmtcachesize(self, value):
+        self._connection.stmtcachesize = value
+
+    def cursor(self):
+        return AsyncAdapt_oracledb_cursor(self)
+
+
+class OracledbAdaptDBAPI:
+    def __init__(self, oracledb) -> None:
+        self.oracledb = oracledb
+
+        for k, v in self.oracledb.__dict__.items():
+            if k != "connect":
+                self.__dict__[k] = v
+
+    def connect(self, *arg, **kw):
+        creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
+        return AsyncAdapt_oracledb_connection(
+            self, await_(creator_fn(*arg, **kw))
+        )
+
+
+class OracleDialectAsync_oracledb(OracleDialect_oracledb):
+    is_async = True
+    supports_statement_cache = True
+
+    _min_version = (2,)
+
+    # thick_mode mode is not supported by asyncio, oracledb will raise
+    @classmethod
+    def import_dbapi(cls):
+        import oracledb
+
+        return OracledbAdaptDBAPI(oracledb)
+
+    def get_driver_connection(self, connection):
+        return connection._connection
+
+
 dialect = OracleDialect_oracledb
+dialect_async = OracleDialectAsync_oracledb
index 4ea9cbf3f8be96ff73ceb7d3dcda584c2d7e2e2a..9c18b7e667531c125dd8a664baeac9c15d4ab059 100644 (file)
@@ -80,6 +80,8 @@ from ...util.concurrency import await_
 if TYPE_CHECKING:
     from typing import Iterable
 
+    from psycopg import AsyncConnection
+
 logger = logging.getLogger("sqlalchemy.dialects.postgresql")
 
 
@@ -588,6 +590,7 @@ class AsyncAdapt_psycopg_ss_cursor(
 
 
 class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
+    _connection: AsyncConnection
     __slots__ = ()
 
     _cursor_cls = AsyncAdapt_psycopg_cursor
index 5c4ec8cd050594972f62a954f36ee6cf074c1bcd..02b70ecd583389331f1b8eeb5f507d7ec8913187 100644 (file)
@@ -573,6 +573,11 @@ class AsyncConnection(
             :meth:`.AsyncConnection.stream_scalars`
 
         """
+        if not self.dialect.supports_server_side_cursors:
+            raise exc.InvalidRequestError(
+                "Cant use `stream` or `stream_scalars` with the current "
+                "dialect since it does not support server side cursors."
+            )
 
         result = await greenlet_spawn(
             self._proxied.execute,
index cdde264cb087d4809da7433af5c0c7637d685e2c..74cdb0c73d9beef874cfeecd2027e81331ff44c4 100644 (file)
@@ -146,7 +146,10 @@ def generate_db_urls(db_urls, extra_drivers):
     ]
 
     for url_obj, dialect in urls_plus_dialects:
-        backend_to_driver_we_already_have[dialect.name].add(dialect.driver)
+        # use get_driver_name instead of dialect.driver to account for
+        # "_async" virtual drivers like oracledb and psycopg
+        driver_name = url_obj.get_driver_name()
+        backend_to_driver_we_already_have[dialect.name].add(driver_name)
 
     backend_to_driver_we_need = {}
 
index f9248486262ed71f6e5d300996bd94c18d4eead2..2ff94822c6465f9139fd5cb81777d98c7d7ce363 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -178,4 +178,5 @@ docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+D
 oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
 cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
 oracledb = oracle+oracledb://scott:tiger@oracle18c/xe
+oracledb_async = oracle+oracledb_async://scott:tiger@oracle18c/xe
 docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=FREEPDB1
index 93cf0b74578aaf94147048310eb4344d4889473e..68ee3f7180085c7c6ff2fa6dce3c6a24a8e0bc6a 100644 (file)
@@ -36,6 +36,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.assertions import is_
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.schema import Table
@@ -69,6 +70,8 @@ class CxOracleDialectTest(fixtures.TestBase):
 
 
 class OracleDbDialectTest(fixtures.TestBase):
+    __only_on__ = "oracle+oracledb"
+
     def test_oracledb_version_parse(self):
         dialect = oracledb.OracleDialect_oracledb()
 
@@ -84,19 +87,36 @@ class OracleDbDialectTest(fixtures.TestBase):
     def test_minimum_version(self):
         with expect_raises_message(
             exc.InvalidRequestError,
-            "oracledb version 1 and above are supported",
+            r"oracledb version \(1,\) and above are supported",
         ):
             oracledb.OracleDialect_oracledb(dbapi=Mock(version="0.1.5"))
 
         dialect = oracledb.OracleDialect_oracledb(dbapi=Mock(version="7.1.0"))
         eq_(dialect.oracledb_ver, (7, 1, 0))
 
+    def test_get_dialect(self):
+        u = url.URL.create("oracle://")
+        d = oracledb.OracleDialect_oracledb.get_dialect_cls(u)
+        is_(d, oracledb.OracleDialect_oracledb)
+        d = oracledb.OracleDialect_oracledb.get_async_dialect_cls(u)
+        is_(d, oracledb.OracleDialectAsync_oracledb)
+        d = oracledb.OracleDialectAsync_oracledb.get_dialect_cls(u)
+        is_(d, oracledb.OracleDialectAsync_oracledb)
+        d = oracledb.OracleDialectAsync_oracledb.get_dialect_cls(u)
+        is_(d, oracledb.OracleDialectAsync_oracledb)
+
+    def test_async_version(self):
+        e = create_engine("oracle+oracledb_async://")
+        is_true(isinstance(e.dialect, oracledb.OracleDialectAsync_oracledb))
+
 
 class OracledbMode(fixtures.TestBase):
     __backend__ = True
     __only_on__ = "oracle+oracledb"
 
     def _run_in_process(self, fn, fn_kw=None):
+        if config.db.dialect.is_async:
+            config.skip_test("thick mode unsupported in async mode")
         ctx = get_context("spawn")
         queue = ctx.Queue()
         process = ctx.Process(
@@ -202,6 +222,7 @@ class DialectWBackendTest(fixtures.TestBase):
                 testing.db.dialect.get_isolation_level(dbapi_conn),
                 "READ COMMITTED",
             )
+            conn.close()
 
     def test_graceful_failure_isolation_level_not_available(self):
         engine = engines.testing_engine()
index 82a81612e1ef5b903c8ed20c2fa55466a1efe28a..3bf78c105a047316f13a65ebe413e01fb8d94d00 100644 (file)
@@ -50,6 +50,7 @@ from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.util import b
+from sqlalchemy.util.concurrency import await_
 
 
 def exec_sql(conn, sql, *args, **kwargs):
@@ -998,13 +999,23 @@ class LOBFetchTest(fixtures.TablesTest):
         for i in range(1, 11):
             connection.execute(binary_table.insert(), dict(id=i, data=stream))
 
+    def _read_lob(self, engine, row):
+        if engine.dialect.is_async:
+            data = await_(row._mapping["data"].read())
+            bindata = await_(row._mapping["bindata"].read())
+        else:
+            data = row._mapping["data"].read()
+            bindata = row._mapping["bindata"].read()
+        return data, bindata
+
     def test_lobs_without_convert(self):
         engine = testing_engine(options=dict(auto_convert_lobs=False))
         t = self.tables.z_test
         with engine.begin() as conn:
             row = conn.execute(t.select().where(t.c.id == 1)).first()
-            eq_(row._mapping["data"].read(), "this is text 1")
-            eq_(row._mapping["bindata"].read(), b("this is binary 1"))
+            data, bindata = self._read_lob(engine, row)
+            eq_(data, "this is text 1")
+            eq_(bindata, b("this is binary 1"))
 
     def test_lobs_with_convert(self, connection):
         t = self.tables.z_test
@@ -1028,17 +1039,13 @@ class LOBFetchTest(fixtures.TablesTest):
         results = result.fetchall()
 
         def go():
-            eq_(
-                [
-                    dict(
-                        id=row._mapping["id"],
-                        data=row._mapping["data"].read(),
-                        bindata=row._mapping["bindata"].read(),
-                    )
-                    for row in results
-                ],
-                self.data,
-            )
+            actual = []
+            for row in results:
+                data, bindata = self._read_lob(engine, row)
+                actual.append(
+                    dict(id=row._mapping["id"], data=data, bindata=bindata)
+                )
+            eq_(actual, self.data)
 
         # this comes from cx_Oracle because these are raw
         # cx_Oracle.Variable objects
index 5ca465906a897210dd4360986ca359a5692216be..15a0ebfd7f22df416e818854575a195654db01c5 100644 (file)
@@ -785,6 +785,27 @@ class AsyncEngineTest(EngineFixture):
         finally:
             await greenlet_spawn(conn.close)
 
+    @testing.combinations("stream", "stream_scalars", argnames="method")
+    @async_test
+    async def test_server_side_required_for_scalars(
+        self, async_engine, method
+    ):
+        with mock.patch.object(
+            async_engine.dialect, "supports_server_side_cursors", False
+        ):
+            async with async_engine.connect() as c:
+                with expect_raises_message(
+                    exc.InvalidRequestError,
+                    "Cant use `stream` or `stream_scalars` with the current "
+                    "dialect since it does not support server side cursors.",
+                ):
+                    if method == "stream":
+                        await c.stream(select(1))
+                    elif method == "stream_scalars":
+                        await c.stream_scalars(select(1))
+                    else:
+                        testing.fail(method)
+
 
 class AsyncCreatePoolTest(fixtures.TestBase):
     @config.fixture
@@ -857,44 +878,44 @@ class AsyncEventTest(EngineFixture):
         ):
             event.listen(async_engine, "checkout", mock.Mock())
 
+    def select1(self, engine):
+        if engine.dialect.name == "oracle":
+            return "select 1 from dual"
+        else:
+            return "select 1"
+
     @async_test
     async def test_sync_before_cursor_execute_engine(self, async_engine):
         canary = mock.Mock()
 
         event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
 
+        s1 = self.select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
-            await conn.execute(text("select 1"))
+            await conn.execute(text(s1))
 
         eq_(
             canary.mock_calls,
-            [
-                mock.call(
-                    sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
-                )
-            ],
+            [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
         )
 
     @async_test
     async def test_sync_before_cursor_execute_connection(self, async_engine):
         canary = mock.Mock()
 
+        s1 = self.select1(async_engine)
         async with async_engine.connect() as conn:
             sync_conn = conn.sync_connection
 
             event.listen(
                 async_engine.sync_engine, "before_cursor_execute", canary
             )
-            await conn.execute(text("select 1"))
+            await conn.execute(text(s1))
 
         eq_(
             canary.mock_calls,
-            [
-                mock.call(
-                    sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
-                )
-            ],
+            [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
         )
 
     @async_test
@@ -932,6 +953,9 @@ class AsyncInspection(EngineFixture):
 
 
 class AsyncResultTest(EngineFixture):
+    __backend__ = True
+    __requires__ = ("server_side_cursors", "async_dialect")
+
     @async_test
     async def test_no_ss_cursor_w_execute(self, async_engine):
         users = self.tables.users
@@ -1259,7 +1283,13 @@ class TextSyncDBAPI(fixtures.TestBase):
     def async_engine(self):
         engine = create_engine("sqlite:///:memory:", future=True)
         engine.dialect.is_async = True
-        return _async_engine.AsyncEngine(engine)
+        engine.dialect.supports_server_side_cursors = True
+        with mock.patch.object(
+            engine.dialect.execution_ctx_cls,
+            "create_server_side_cursor",
+            engine.dialect.execution_ctx_cls.create_default_cursor,
+        ):
+            yield _async_engine.AsyncEngine(engine)
 
     @async_test
     @combinations(
index e38a0cc52a90f7741d67c17999dc5781f7529226..2d6ce09da3af052672d8f6c651ad6bcd22b42e4f 100644 (file)
@@ -4,7 +4,6 @@ import contextlib
 from typing import List
 from typing import Optional
 
-from sqlalchemy import Column
 from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
@@ -47,6 +46,7 @@ from sqlalchemy.testing.assertions import is_false
 from sqlalchemy.testing.assertions import not_in
 from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.provision import normalize_sequence
+from sqlalchemy.testing.schema import Column
 from .test_engine_py3k import AsyncFixture as _AsyncFixture
 from ...orm import _fixtures
 
@@ -314,6 +314,7 @@ class AsyncSessionQueryTest(AsyncFixture):
 
     @testing.combinations("statement", "execute", argnames="location")
     @async_test
+    @testing.requires.server_side_cursors
     async def test_no_ss_cursor_w_execute(self, async_session, location):
         User = self.classes.User
 
@@ -767,7 +768,9 @@ class AsyncORMBehaviorsTest(AsyncFixture):
             class A:
                 __tablename__ = "a"
 
-                id = Column(Integer, primary_key=True)
+                id = Column(
+                    Integer, primary_key=True, test_needs_autoincrement=True
+                )
                 b = relationship(
                     "B",
                     uselist=False,
@@ -779,7 +782,9 @@ class AsyncORMBehaviorsTest(AsyncFixture):
             @registry.mapped
             class B:
                 __tablename__ = "b"
-                id = Column(Integer, primary_key=True)
+                id = Column(
+                    Integer, primary_key=True, test_needs_autoincrement=True
+                )
                 a_id = Column(ForeignKey("a.id"))
 
             async with async_engine.begin() as conn:
@@ -790,14 +795,8 @@ class AsyncORMBehaviorsTest(AsyncFixture):
         return go
 
     @testing.combinations(
-        (
-            "legacy_style",
-            True,
-        ),
-        (
-            "new_style",
-            False,
-        ),
+        ("legacy_style", True),
+        ("new_style", False),
         argnames="_legacy_inactive_history_style",
         id_="ia",
     )
index af51010c761a0169db4326ac419d8cc74d24fd33..c841e364db5e6aa602dbbc303f20f228726b2cc3 100644 (file)
@@ -86,7 +86,7 @@ class LoopOperate(operators.ColumnOperators):
 class DefaultColumnComparatorTest(
     testing.AssertsCompiledSQL, fixtures.TestBase
 ):
-    dialect = "default_enhanced"
+    dialect = __dialect__ = "default_enhanced"
 
     @testing.combinations((operators.desc_op, desc), (operators.asc_op, asc))
     def test_scalar(self, operator, compare_to):
diff --git a/tox.ini b/tox.ini
index 4c3cca1f76a24df20c2d7a75fd96b16a564194db..cd07aa9620259ec8be3eabd8b326e578b47ac23f 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -111,7 +111,7 @@ setenv=
 
     oracle: WORKERS={env:TOX_WORKERS:-n2  --max-worker-restart=5}
     oracle: ORACLE={env:TOX_ORACLE:--db oracle}
-    oracle: EXTRA_ORACLE_DRIVERS={env:EXTRA_ORACLE_DRIVERS:--dbdriver cx_oracle --dbdriver oracledb}
+    oracle: EXTRA_ORACLE_DRIVERS={env:EXTRA_ORACLE_DRIVERS:--dbdriver cx_oracle --dbdriver oracledb --dbdriver oracledb_async}
 
     sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}