From: abdallah elhdad Date: Sun, 18 Jan 2026 14:06:58 +0000 (+0200) Subject: Add JSON type support for Oracle dialect X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=refs%2Fpull%2F13065%2Fhead;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add JSON type support for Oracle dialect Signed-off-by: abdallah elhdad --- diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index 2b12b0db86..38266e6ce7 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -36,6 +36,7 @@ from .base import VARCHAR2 from .base import VECTOR from .base import VectorIndexConfig from .base import VectorIndexType +from .json import JSON from .vector import SparseVector from .vector import VectorDistanceType from .vector import VectorStorageFormat @@ -80,4 +81,5 @@ __all__ = ( "VectorStorageFormat", "VectorStorageType", "SparseVector", + "JSON", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 63a8d45cc5..dcc303630c 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1000,8 +1000,14 @@ from dataclasses import fields from functools import lru_cache from functools import wraps import re +from typing import Any +from typing import Callable +from typing import Optional from . import dictionary +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType from .types import _OracleBoolean from .types import _OracleDate from .types import BFILE @@ -1036,6 +1042,7 @@ from ...engine.reflection import ReflectionDefaults from ...sql import and_ from ...sql import bindparam from ...sql import compiler +from ...sql import elements from ...sql import expression from ...sql import func from ...sql import null @@ -1079,6 +1086,9 @@ colspecs = { sqltypes.Interval: INTERVAL, sqltypes.DateTime: DATE, sqltypes.Date: _OracleDate, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, } ischema_names = { @@ -1106,6 +1116,7 @@ ischema_names = { "ROWID": ROWID, "BOOLEAN": BOOLEAN, "VECTOR": VECTOR, + "JSON": JSON, } @@ -1278,6 +1289,9 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): ) return f"VECTOR({dim},{storage_format},{storage_type})" + def visit_JSON(self, type_: JSON, **kw: Any) -> str: + return "JSON" + class OracleCompiler(compiler.SQLCompiler): """Oracle compiler modifies the lexical structure of Select @@ -1321,6 +1335,23 @@ class OracleCompiler(compiler.SQLCompiler): def visit_false(self, expr, **kw): return "0" + def visit_cast(self, cast, **kwargs): + # Oracle requires VARCHAR2 to have a length in CAST expressions + # Adapt String types to VARCHAR2 with appropriate length + type_ = cast.typeclause.type + if isinstance(type_, sqltypes.String) and not isinstance( + type_, (sqltypes.Text, sqltypes.CLOB) + ): + adapted = VARCHAR2._adapt_string_for_cast(type_) + type_clause = self.dialect.type_compiler_instance.process(adapted) + else: + type_clause = cast.typeclause._compiler_dispatch(self, **kwargs) + + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + type_clause, + ) + def get_cte_preamble(self, recursive): return "WITH" @@ -1790,6 +1821,80 @@ class OracleCompiler(compiler.SQLCompiler): def visit_bitwise_not_op_unary_operator(self, element, operator, **kw): raise exc.CompileError("Cannot compile bitwise_not in oracle") + def _render_json_extract_from_binary(self, binary, operator, **kw): + literal_kw = kw.copy() + literal_kw["literal_binds"] = True + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + + case_expression = "CASE JSON_VALUE(%s, %s) WHEN NULL THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS INTEGER)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + + elif binary.type._type_affinity in (sqltypes.Numeric, sqltypes.Float): + if isinstance(binary.type, sqltypes.Float): + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS FLOAT)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + else: + type_expression = ( + "ELSE CAST(JSON_VALUE(%s, %s) AS NUMBER(%s, %s))" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + binary.type.precision, + binary.type.scale, + ) + ) + + elif binary.type._type_affinity is sqltypes.Boolean: + type_expression = ( + "WHEN 'true' THEN 1 " + "WHEN 'false' THEN 0 " + "ELSE CAST(JSON_VALUE(%s, %s) AS NUMBER(1))" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + ) + + elif binary.type._type_affinity is sqltypes.String: + type_expression = "ELSE JSON_VALUE(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + + else: + # Fallback: preserve JSON structure + type_expression = "ELSE JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **literal_kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: + return self._render_json_extract_from_binary(binary, operator, **kw) + class OracleDDLCompiler(compiler.DDLCompiler): @@ -2077,6 +2182,8 @@ class OracleDialect(default.DefaultDialect): use_nchar_for_unicode=False, exclude_tablespaces=("SYSTEM", "SYSAUX"), enable_offset_fetch=True, + json_serializer: Optional[Callable[..., Any]] = None, + json_deserializer: Optional[Callable[..., Any]] = None, **kwargs, ): default.DefaultDialect.__init__(self, **kwargs) @@ -2087,6 +2194,8 @@ class OracleDialect(default.DefaultDialect): self.enable_offset_fetch = self._supports_offset_fetch = ( enable_offset_fetch ) + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer def initialize(self, connection): super().initialize(connection) diff --git a/lib/sqlalchemy/dialects/oracle/json.py b/lib/sqlalchemy/dialects/oracle/json.py new file mode 100644 index 0000000000..be01b73f9e --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/json.py @@ -0,0 +1,103 @@ +# dialects/oracle/json.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations +from typing import Any +from decimal import Decimal +from typing import TYPE_CHECKING +import json + +from ... import types as sqltypes + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType + + +class JSON(sqltypes.JSON): + """ + Note: The oracledb Python driver automatically deserializes JSON column data, + returning native Python objects (dict, list, bool, int, float, str) directly. + """ + + def result_processor(self, dialect, coltype): # type: ignore[override] + string_process = self._str_impl.result_processor(dialect, coltype) + json_deserializer = getattr(dialect, "_json_deserializer", None) or json.loads + + def process(value): + if value is None: + return None + + if string_process: + value = string_process(value) + + if isinstance(value, Decimal): + return float(value) + + # If it's a string, it might be JSON that needs deserializing + # This can happen with CAST operations or when reading from VARCHAR2 columns + if isinstance(value, str): + try: + return json_deserializer(value) + except (json.JSONDecodeError, TypeError): + return value + + return value + + return process + + +class _FormatTypeMixin: + def _format_value(self, value: Any) -> str: + raise NotImplementedError() + + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: + super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501 + + def process(value: Any) -> Any: + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: + super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501 + + def process(value: Any) -> str: + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value # type: ignore[no-any-return] + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value: Any) -> str: + if isinstance(value, int): + formatted_value = "$[%s]" % value + else: + formatted_value = '$."%s"' % value + return formatted_value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value: Any) -> str: + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index d1a2a39afb..67fcd09269 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -40,6 +40,23 @@ class NCLOB(sqltypes.Text): class VARCHAR2(VARCHAR): __visit_name__ = "VARCHAR2" + @classmethod + def _adapt_string_for_cast(cls, type_: sqltypes.String) -> "VARCHAR2": + """Adapt a String type for use in CAST expressions. + + Oracle requires a length for VARCHAR2 in CAST expressions. + If no length is specified, we default to 4000 (max for VARCHAR2). + """ + type_ = sqltypes.to_instance(type_) + if isinstance(type_, VARCHAR2): + return type_ + elif isinstance(type_, VARCHAR): + return VARCHAR2( + length=type_.length or 4000, collation=type_.collation + ) + else: + return VARCHAR2(length=type_.length or 4000) + NVARCHAR2 = NVARCHAR diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index b79d4b952f..c56eee1a69 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1926,6 +1926,16 @@ class SuiteRequirements(Requirements): "indicates if the json_deserializer function is called with bytes" return exclusions.closed() + @property + def json_deserializer_is_used(self): + """Indicates if custom json_deserializer is called for JSON columns. + + Some database drivers (e.g., Oracle's oracledb) automatically + deserialize JSON at the DBAPI level, returning native Python objects + directly, which means custom json_deserializer cannot be invoked. + """ + return exclusions.closed() + exclusions.skip_if(["oracle"]) + @property def reflect_table_options(self): """Target database must support reflecting table_options.""" diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index bf720d9193..1ddce29f3b 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -1318,7 +1318,9 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): Table( "data_table", metadata, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("name", String(30), nullable=False), Column("data", cls.datatype, nullable=False), Column("nulldata", cls.datatype(none_as_null=True)), @@ -1592,6 +1594,9 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): eq_(row, (data_element, data_element)) def test_round_trip_custom_json(self): + if not testing.requires.json_deserializer_is_used.enabled: + return + data_table = self.tables.data_table data_element = {"key1": "data1"} @@ -1611,6 +1616,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): eq_(row, (data_element,)) eq_(js.mock_calls, [mock.call(data_element)]) + if testing.requires.json_deserializer_binary.enabled: eq_( jd.mock_calls, @@ -1936,7 +1942,9 @@ class JSONLegacyStringCastIndexTest( Table( "data_table", metadata, - Column("id", Integer, primary_key=True), + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), Column("name", String(30), nullable=False), Column("data", cls.datatype), Column("nulldata", cls.datatype(none_as_null=True)), diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 33db8cee75..b262249c31 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1806,3 +1806,43 @@ class SetInputSizesTest(fixtures.TestBase): ) finally: event.remove(testing.db, "do_setinputsizes", _remove_type) + +class JSONTest(fixtures.TestBase): + __requires__ = ("json_type",) + __only_on__ = "oracle" + __backend__ = True + + @testing.requires.reflects_json_type + def test_reflection(self, metadata, connection): + Table("oracle_json", metadata, Column("foo", oracle.JSON)) + metadata.create_all(connection) + + reflected = Table("oracle_json", MetaData(), autoload_with=connection) + is_(reflected.c.foo.type._type_affinity, sqltypes.JSON) + assert isinstance(reflected.c.foo.type, oracle.JSON) + + def test_rudimentary_round_trip(self, metadata, connection): + oracle_json = Table( + "oracle_json", metadata, Column("foo", oracle.JSON) + ) + metadata.create_all(connection) + + value = {"json": {"foo": "bar"}, "recs": ["one", "two"]} + + connection.execute(oracle_json.insert(), dict(foo=value)) + + eq_(connection.scalar(select(oracle_json.c.foo)), value) + + def test_extract_subobject(self, connection, metadata): + oracle_json = Table( + "oracle_json", metadata, Column("foo", oracle.JSON) + ) + metadata.create_all(connection) + + value = {"json": {"foo": "bar"}} + connection.execute(oracle_json.insert(), dict(foo=value)) + + eq_( + connection.scalar(select(oracle_json.c.foo["json"])), + value["json"], + ) diff --git a/test/requirements.py b/test/requirements.py index cdc5e4f869..796caea7f3 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1247,8 +1247,9 @@ class DefaultRequirements(SuiteRequirements): "postgresql >= 9.3", self._sqlite_json, "mssql", + "oracle>=21", ] - ) + ) + skip_if("oracle+cx_oracle") @property def json_index_supplementary_unicode_element(self): @@ -1339,6 +1340,7 @@ class DefaultRequirements(SuiteRequirements): and not config.db.dialect._is_mariadb, "postgresql >= 9.3", "sqlite >= 3.9", + "oracle>=21", ] ) @@ -2180,6 +2182,16 @@ class DefaultRequirements(SuiteRequirements): "indicates if the json_deserializer function is called with bytes" return only_on(["postgresql+psycopg"]) + @property + def json_deserializer_is_used(self): + """Indicates if custom json_deserializer is called for JSON columns. + + Some database drivers (e.g., Oracle's oracledb) automatically + deserialize JSON at the DBAPI level, returning native Python objects + directly, which means custom json_deserializer cannot be invoked. + """ + return skip_if(["oracle"]) + @property def mssql_filestream(self): "returns if mssql supports filestream"