Added support for :ref:`oracledb` in async mode.
The current implementation has some limitation, preventing
the support for :meth:`_asyncio.AsyncConnection.stream`.
Improved support if planned for the 2.1 release of SQLAlchemy.
Fixes: #10679
Change-Id: Iff123cf6241bcfa0fbac57529b80f933951be0a7
--- /dev/null
+.. change::
+ :tags: oracle, asyncio
+ :tickets: 10679
+
+ Added support for :ref:`oracledb` in async mode.
+ The current implementation has some limitation, preventing
+ the support for :meth:`_asyncio.AsyncConnection.stream`.
+ Improved support if planned for the 2.1 release of SQLAlchemy.
return (), kw
- def _do_isolation_level(self, connection, autocommit, isolation_level):
- connection.set_autocommit(autocommit)
- connection.set_isolation_level(isolation_level)
-
- def _do_autocommit(self, connection, value):
- connection.set_autocommit(value)
-
- def set_readonly(self, connection, value):
- connection.set_read_only(value)
-
- def set_deferrable(self, connection, value):
- connection.set_deferrable(value)
-
def get_driver_connection(self, connection):
return connection._connection
self._connection = adapt_connection._connection
cursor = self._make_new_cursor(self._connection)
+ self._cursor = self._aenter_cursor(cursor)
+ self._rows = collections.deque()
+
+ def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor:
try:
- self._cursor = await_(cursor.__aenter__())
+ return await_(cursor.__aenter__()) # type: ignore[no-any-return]
except Exception as error:
self._adapt_connection._handle_exception(error)
- self._rows = collections.deque()
-
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
result = await self._cursor.execute(operation, parameters)
if self._cursor.description and not self.server_side:
- # aioodbc has a "fake" async result, so we have to pull it out
- # of that here since our default result is not async.
- # we could just as easily grab "_rows" here and be done with it
- # but this is safer.
self._rows = collections.deque(await self._cursor.fetchall())
return result
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
-
+from types import ModuleType
from . import base # noqa
from . import cx_oracle # noqa
from .base import VARCHAR
from .base import VARCHAR2
+# Alias oracledb also as oracledb_async
+oracledb_async = type(
+ "oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async}
+)
base.dialect = dialect = cx_oracle.dialect
out_parameters[name] = self.cursor.var(
dbtype,
+ # this is fine also in oracledb_async since
+ # the driver will await the read coroutine
outconverter=lambda value: value.read(),
arraysize=len_params,
)
:ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver
as well.
+The SQLAlchemy ``oracledb`` dialect provides both a sync and an async
+implementation under the same dialect name. The proper version is
+selected depending on how the engine is created:
+
+* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will
+ automatically select the sync version, e.g.::
+
+ from sqlalchemy import create_engine
+ sync_engine = create_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
+
+* calling :func:`_asyncio.create_async_engine` with
+ ``oracle+oracledb://...`` will automatically select the async version,
+ e.g.::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ asyncio_engine = create_async_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
+
+The asyncio version of the dialect may also be specified explicitly using the
+``oracledb_async`` suffix, as::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ asyncio_engine = create_async_engine("oracle+oracledb_async://scott:tiger@localhost/?service_name=XEPDB1")
+
+.. versionadded:: 2.0.25 added support for the async version of oracledb.
+
Thick mode support
------------------
.. versionadded:: 2.0.0 added support for oracledb driver.
""" # noqa
+from __future__ import annotations
+
+import collections
import re
+from typing import Any
+from typing import TYPE_CHECKING
from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
from ... import exc
+from ...connectors.asyncio import AsyncAdapt_dbapi_connection
+from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...util import await_
+
+if TYPE_CHECKING:
+ from oracledb import AsyncConnection
+ from oracledb import AsyncCursor
class OracleDialect_oracledb(_OracleDialect_cx_oracle):
supports_statement_cache = True
driver = "oracledb"
+ _min_version = (1,)
def __init__(
self,
def is_thin_mode(cls, connection):
return connection.connection.dbapi_connection.thin
+ @classmethod
+ def get_async_dialect_cls(cls, url):
+ return OracleDialectAsync_oracledb
+
def _load_version(self, dbapi_module):
version = (0, 0, 0)
if dbapi_module is not None:
int(x) for x in m.group(1, 2, 3) if x is not None
)
self.oracledb_ver = version
- if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0):
+ if (
+ self.oracledb_ver > (0, 0, 0)
+ and self.oracledb_ver < self._min_version
+ ):
raise exc.InvalidRequestError(
- "oracledb version 1 and above are supported"
+ f"oracledb version {self._min_version} and above are supported"
)
+class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
+ _cursor: AsyncCursor
+ __slots__ = ()
+
+ @property
+ def outputtypehandler(self):
+ return self._cursor.outputtypehandler
+
+ @outputtypehandler.setter
+ def outputtypehandler(self, value):
+ self._cursor.outputtypehandler = value
+
+ def var(self, *args, **kwargs):
+ return self._cursor.var(*args, **kwargs)
+
+ def close(self):
+ self._rows.clear()
+ self._cursor.close()
+
+ def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
+ return self._cursor.setinputsizes(*args, **kwargs)
+
+ def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor:
+ try:
+ return cursor.__enter__()
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ async def _execute_async(self, operation, parameters):
+ # override to not use mutex, oracledb already has mutex
+
+ if parameters is None:
+ result = await self._cursor.execute(operation)
+ else:
+ result = await self._cursor.execute(operation, parameters)
+
+ if self._cursor.description and not self.server_side:
+ self._rows = collections.deque(await self._cursor.fetchall())
+ return result
+
+ async def _executemany_async(
+ self,
+ operation,
+ seq_of_parameters,
+ ):
+ # override to not use mutex, oracledb already has mutex
+ return await self._cursor.executemany(operation, seq_of_parameters)
+
+
+class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
+ _connection: AsyncConnection
+ __slots__ = ()
+
+ thin = True
+
+ _cursor_cls = AsyncAdapt_oracledb_cursor
+ _ss_cursor_cls = None
+
+ @property
+ def autocommit(self):
+ return self._connection.autocommit
+
+ @autocommit.setter
+ def autocommit(self, value):
+ self._connection.autocommit = value
+
+ @property
+ def outputtypehandler(self):
+ return self._connection.outputtypehandler
+
+ @outputtypehandler.setter
+ def outputtypehandler(self, value):
+ self._connection.outputtypehandler = value
+
+ @property
+ def version(self):
+ return self._connection.version
+
+ @property
+ def stmtcachesize(self):
+ return self._connection.stmtcachesize
+
+ @stmtcachesize.setter
+ def stmtcachesize(self, value):
+ self._connection.stmtcachesize = value
+
+ def cursor(self):
+ return AsyncAdapt_oracledb_cursor(self)
+
+
+class OracledbAdaptDBAPI:
+ def __init__(self, oracledb) -> None:
+ self.oracledb = oracledb
+
+ for k, v in self.oracledb.__dict__.items():
+ if k != "connect":
+ self.__dict__[k] = v
+
+ def connect(self, *arg, **kw):
+ creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
+ return AsyncAdapt_oracledb_connection(
+ self, await_(creator_fn(*arg, **kw))
+ )
+
+
+class OracleDialectAsync_oracledb(OracleDialect_oracledb):
+ is_async = True
+ supports_statement_cache = True
+
+ _min_version = (2,)
+
+ # thick_mode mode is not supported by asyncio, oracledb will raise
+ @classmethod
+ def import_dbapi(cls):
+ import oracledb
+
+ return OracledbAdaptDBAPI(oracledb)
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
dialect = OracleDialect_oracledb
+dialect_async = OracleDialectAsync_oracledb
if TYPE_CHECKING:
from typing import Iterable
+ from psycopg import AsyncConnection
+
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
+ _connection: AsyncConnection
__slots__ = ()
_cursor_cls = AsyncAdapt_psycopg_cursor
:meth:`.AsyncConnection.stream_scalars`
"""
+ if not self.dialect.supports_server_side_cursors:
+ raise exc.InvalidRequestError(
+ "Cant use `stream` or `stream_scalars` with the current "
+ "dialect since it does not support server side cursors."
+ )
result = await greenlet_spawn(
self._proxied.execute,
]
for url_obj, dialect in urls_plus_dialects:
- backend_to_driver_we_already_have[dialect.name].add(dialect.driver)
+ # use get_driver_name instead of dialect.driver to account for
+ # "_async" virtual drivers like oracledb and psycopg
+ driver_name = url_obj.get_driver_name()
+ backend_to_driver_we_already_have[dialect.name].add(driver_name)
backend_to_driver_we_need = {}
oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
oracledb = oracle+oracledb://scott:tiger@oracle18c/xe
+oracledb_async = oracle+oracledb_async://scott:tiger@oracle18c/xe
docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=FREEPDB1
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.assertions import is_
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import pep435_enum
from sqlalchemy.testing.schema import Table
class OracleDbDialectTest(fixtures.TestBase):
+ __only_on__ = "oracle+oracledb"
+
def test_oracledb_version_parse(self):
dialect = oracledb.OracleDialect_oracledb()
def test_minimum_version(self):
with expect_raises_message(
exc.InvalidRequestError,
- "oracledb version 1 and above are supported",
+ r"oracledb version \(1,\) and above are supported",
):
oracledb.OracleDialect_oracledb(dbapi=Mock(version="0.1.5"))
dialect = oracledb.OracleDialect_oracledb(dbapi=Mock(version="7.1.0"))
eq_(dialect.oracledb_ver, (7, 1, 0))
+ def test_get_dialect(self):
+ u = url.URL.create("oracle://")
+ d = oracledb.OracleDialect_oracledb.get_dialect_cls(u)
+ is_(d, oracledb.OracleDialect_oracledb)
+ d = oracledb.OracleDialect_oracledb.get_async_dialect_cls(u)
+ is_(d, oracledb.OracleDialectAsync_oracledb)
+ d = oracledb.OracleDialectAsync_oracledb.get_dialect_cls(u)
+ is_(d, oracledb.OracleDialectAsync_oracledb)
+ d = oracledb.OracleDialectAsync_oracledb.get_dialect_cls(u)
+ is_(d, oracledb.OracleDialectAsync_oracledb)
+
+ def test_async_version(self):
+ e = create_engine("oracle+oracledb_async://")
+ is_true(isinstance(e.dialect, oracledb.OracleDialectAsync_oracledb))
+
class OracledbMode(fixtures.TestBase):
__backend__ = True
__only_on__ = "oracle+oracledb"
def _run_in_process(self, fn, fn_kw=None):
+ if config.db.dialect.is_async:
+ config.skip_test("thick mode unsupported in async mode")
ctx = get_context("spawn")
queue = ctx.Queue()
process = ctx.Process(
testing.db.dialect.get_isolation_level(dbapi_conn),
"READ COMMITTED",
)
+ conn.close()
def test_graceful_failure_isolation_level_not_available(self):
engine = engines.testing_engine()
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from sqlalchemy.util import b
+from sqlalchemy.util.concurrency import await_
def exec_sql(conn, sql, *args, **kwargs):
for i in range(1, 11):
connection.execute(binary_table.insert(), dict(id=i, data=stream))
+ def _read_lob(self, engine, row):
+ if engine.dialect.is_async:
+ data = await_(row._mapping["data"].read())
+ bindata = await_(row._mapping["bindata"].read())
+ else:
+ data = row._mapping["data"].read()
+ bindata = row._mapping["bindata"].read()
+ return data, bindata
+
def test_lobs_without_convert(self):
engine = testing_engine(options=dict(auto_convert_lobs=False))
t = self.tables.z_test
with engine.begin() as conn:
row = conn.execute(t.select().where(t.c.id == 1)).first()
- eq_(row._mapping["data"].read(), "this is text 1")
- eq_(row._mapping["bindata"].read(), b("this is binary 1"))
+ data, bindata = self._read_lob(engine, row)
+ eq_(data, "this is text 1")
+ eq_(bindata, b("this is binary 1"))
def test_lobs_with_convert(self, connection):
t = self.tables.z_test
results = result.fetchall()
def go():
- eq_(
- [
- dict(
- id=row._mapping["id"],
- data=row._mapping["data"].read(),
- bindata=row._mapping["bindata"].read(),
- )
- for row in results
- ],
- self.data,
- )
+ actual = []
+ for row in results:
+ data, bindata = self._read_lob(engine, row)
+ actual.append(
+ dict(id=row._mapping["id"], data=data, bindata=bindata)
+ )
+ eq_(actual, self.data)
# this comes from cx_Oracle because these are raw
# cx_Oracle.Variable objects
finally:
await greenlet_spawn(conn.close)
+ @testing.combinations("stream", "stream_scalars", argnames="method")
+ @async_test
+ async def test_server_side_required_for_scalars(
+ self, async_engine, method
+ ):
+ with mock.patch.object(
+ async_engine.dialect, "supports_server_side_cursors", False
+ ):
+ async with async_engine.connect() as c:
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "Cant use `stream` or `stream_scalars` with the current "
+ "dialect since it does not support server side cursors.",
+ ):
+ if method == "stream":
+ await c.stream(select(1))
+ elif method == "stream_scalars":
+ await c.stream_scalars(select(1))
+ else:
+ testing.fail(method)
+
class AsyncCreatePoolTest(fixtures.TestBase):
@config.fixture
):
event.listen(async_engine, "checkout", mock.Mock())
+ def select1(self, engine):
+ if engine.dialect.name == "oracle":
+ return "select 1 from dual"
+ else:
+ return "select 1"
+
@async_test
async def test_sync_before_cursor_execute_engine(self, async_engine):
canary = mock.Mock()
event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
+ s1 = self.select1(async_engine)
async with async_engine.connect() as conn:
sync_conn = conn.sync_connection
- await conn.execute(text("select 1"))
+ await conn.execute(text(s1))
eq_(
canary.mock_calls,
- [
- mock.call(
- sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
- )
- ],
+ [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
)
@async_test
async def test_sync_before_cursor_execute_connection(self, async_engine):
canary = mock.Mock()
+ s1 = self.select1(async_engine)
async with async_engine.connect() as conn:
sync_conn = conn.sync_connection
event.listen(
async_engine.sync_engine, "before_cursor_execute", canary
)
- await conn.execute(text("select 1"))
+ await conn.execute(text(s1))
eq_(
canary.mock_calls,
- [
- mock.call(
- sync_conn, mock.ANY, "select 1", mock.ANY, mock.ANY, False
- )
- ],
+ [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
)
@async_test
class AsyncResultTest(EngineFixture):
+ __backend__ = True
+ __requires__ = ("server_side_cursors", "async_dialect")
+
@async_test
async def test_no_ss_cursor_w_execute(self, async_engine):
users = self.tables.users
def async_engine(self):
engine = create_engine("sqlite:///:memory:", future=True)
engine.dialect.is_async = True
- return _async_engine.AsyncEngine(engine)
+ engine.dialect.supports_server_side_cursors = True
+ with mock.patch.object(
+ engine.dialect.execution_ctx_cls,
+ "create_server_side_cursor",
+ engine.dialect.execution_ctx_cls.create_default_cursor,
+ ):
+ yield _async_engine.AsyncEngine(engine)
@async_test
@combinations(
from typing import List
from typing import Optional
-from sqlalchemy import Column
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy.testing.assertions import not_in
from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.provision import normalize_sequence
+from sqlalchemy.testing.schema import Column
from .test_engine_py3k import AsyncFixture as _AsyncFixture
from ...orm import _fixtures
@testing.combinations("statement", "execute", argnames="location")
@async_test
+ @testing.requires.server_side_cursors
async def test_no_ss_cursor_w_execute(self, async_session, location):
User = self.classes.User
class A:
__tablename__ = "a"
- id = Column(Integer, primary_key=True)
+ id = Column(
+ Integer, primary_key=True, test_needs_autoincrement=True
+ )
b = relationship(
"B",
uselist=False,
@registry.mapped
class B:
__tablename__ = "b"
- id = Column(Integer, primary_key=True)
+ id = Column(
+ Integer, primary_key=True, test_needs_autoincrement=True
+ )
a_id = Column(ForeignKey("a.id"))
async with async_engine.begin() as conn:
return go
@testing.combinations(
- (
- "legacy_style",
- True,
- ),
- (
- "new_style",
- False,
- ),
+ ("legacy_style", True),
+ ("new_style", False),
argnames="_legacy_inactive_history_style",
id_="ia",
)
class DefaultColumnComparatorTest(
testing.AssertsCompiledSQL, fixtures.TestBase
):
- dialect = "default_enhanced"
+ dialect = __dialect__ = "default_enhanced"
@testing.combinations((operators.desc_op, desc), (operators.asc_op, asc))
def test_scalar(self, operator, compare_to):
oracle: WORKERS={env:TOX_WORKERS:-n2 --max-worker-restart=5}
oracle: ORACLE={env:TOX_ORACLE:--db oracle}
- oracle: EXTRA_ORACLE_DRIVERS={env:EXTRA_ORACLE_DRIVERS:--dbdriver cx_oracle --dbdriver oracledb}
+ oracle: EXTRA_ORACLE_DRIVERS={env:EXTRA_ORACLE_DRIVERS:--dbdriver cx_oracle --dbdriver oracledb --dbdriver oracledb_async}
sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}