git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add JSON type support for Oracle dialect 13065/head
author abdallah elhdad <abdallahselhdad@gmail.com>
Sun, 18 Jan 2026 14:06:58 +0000 (16:06 +0200)
committer abdallah elhdad <abdallahselhdad@gmail.com>
Sun, 18 Jan 2026 14:06:58 +0000 (16:06 +0200)
Signed-off-by: abdallah elhdad <abdallahselhdad@gmail.com>
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/json.py [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/types.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/oracle/test_types.py
test/requirements.py

index 2b12b0db8615df8e5eb137aab617d3d294e3a203..38266e6ce7c2e3d5d4b9d6c59d2077afa2794c1a 100644 (file)
@@ -36,6 +36,7 @@ from .base import VARCHAR2
 from .base import VECTOR
 from .base import VectorIndexConfig
 from .base import VectorIndexType
+from .json import JSON
 from .vector import SparseVector
 from .vector import VectorDistanceType
 from .vector import VectorStorageFormat
@@ -80,4 +81,5 @@ __all__ = (
     "VectorStorageFormat",
     "VectorStorageType",
     "SparseVector",
+    "JSON",
 )
index 63a8d45cc54e601d5768a419a73abe6fdd1cee47..dcc303630c5ccfc97aa900d6d2cd5d749b9406db 100644 (file)
@@ -1000,8 +1000,14 @@ from dataclasses import fields
 from functools import lru_cache
 from functools import wraps
 import re
+from typing import Any
+from typing import Callable
+from typing import Optional
 
 from . import dictionary
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
 from .types import _OracleBoolean
 from .types import _OracleDate
 from .types import BFILE
@@ -1036,6 +1042,7 @@ from ...engine.reflection import ReflectionDefaults
 from ...sql import and_
 from ...sql import bindparam
 from ...sql import compiler
+from ...sql import elements
 from ...sql import expression
 from ...sql import func
 from ...sql import null
@@ -1079,6 +1086,9 @@ colspecs = {
     sqltypes.Interval: INTERVAL,
     sqltypes.DateTime: DATE,
     sqltypes.Date: _OracleDate,
+    sqltypes.JSON: JSON,
+    sqltypes.JSON.JSONIndexType: JSONIndexType,
+    sqltypes.JSON.JSONPathType: JSONPathType,
 }
 
 ischema_names = {
@@ -1106,6 +1116,7 @@ ischema_names = {
     "ROWID": ROWID,
     "BOOLEAN": BOOLEAN,
     "VECTOR": VECTOR,
+    "JSON": JSON,
 }
 
 
@@ -1278,6 +1289,9 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
         )
         return f"VECTOR({dim},{storage_format},{storage_type})"
 
+    def visit_JSON(self, type_: JSON, **kw: Any) -> str:
+        """Render the type name used for the Oracle ``JSON`` datatype."""
+        json_keyword = "JSON"
+        return json_keyword
+
 
 class OracleCompiler(compiler.SQLCompiler):
     """Oracle compiler modifies the lexical structure of Select
@@ -1321,6 +1335,23 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_false(self, expr, **kw):
         return "0"
 
+    def visit_cast(self, cast, **kwargs):
+        """Render a ``CAST`` expression.
+
+        Oracle requires VARCHAR2 to have a length in CAST expressions, so
+        String target types other than Text / CLOB are first adapted to a
+        VARCHAR2 with an appropriate length; every other target type is
+        rendered through the cast's own type clause.
+        """
+        target_type = cast.typeclause.type
+        is_plain_string = isinstance(
+            target_type, sqltypes.String
+        ) and not isinstance(target_type, (sqltypes.Text, sqltypes.CLOB))
+
+        if is_plain_string:
+            varchar2_type = VARCHAR2._adapt_string_for_cast(target_type)
+            rendered_type = self.dialect.type_compiler_instance.process(
+                varchar2_type
+            )
+        else:
+            rendered_type = cast.typeclause._compiler_dispatch(self, **kwargs)
+
+        rendered_clause = cast.clause._compiler_dispatch(self, **kwargs)
+        return "CAST(%s AS %s)" % (rendered_clause, rendered_type)
+
+
     def get_cte_preamble(self, recursive):
         return "WITH"
 
@@ -1790,6 +1821,80 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_bitwise_not_op_unary_operator(self, element, operator, **kw):
         raise exc.CompileError("Cannot compile bitwise_not in oracle")
 
+    def _render_json_extract_from_binary(self, binary, operator, **kw):
+        """Render a JSON index / path access using Oracle SQL/JSON functions.
+
+        When the expression's type has JSON affinity a ``JSON_QUERY()`` call
+        is rendered.  Otherwise a ``CASE`` expression around ``JSON_VALUE()``
+        is rendered which converts the extracted scalar for the Integer,
+        Float / Numeric, Boolean or String affinity of the expression; any
+        other type falls back to ``JSON_QUERY()`` inside the ``CASE``.
+
+        :param binary: the binary expression being compiled; ``left`` is the
+         JSON expression and ``right`` is the index / path value.
+        :param operator: the operator being compiled; not used here.
+        :param kw: compiler keyword arguments forwarded to ``process()``.
+        :return: the rendered SQL string.
+        """
+        # the path is always rendered inline as a literal rather than as a
+        # bound parameter
+        literal_kw = kw.copy()
+        literal_kw["literal_binds"] = True
+
+        # compile each side once and reuse the text in every fragment
+        target = self.process(binary.left, **kw)
+        path = self.process(binary.right, **literal_kw)
+        affinity = binary.type._type_affinity
+
+        if affinity is sqltypes.JSON:
+            return "JSON_QUERY(%s, %s)" % (target, path)
+
+        json_value = "JSON_VALUE(%s, %s)" % (target, path)
+
+        # NOTE(review): a simple CASE compares with "=", so "WHEN NULL" is
+        # not expected to ever match; a NULL result comes from the ELSE
+        # branch instead.  Kept so the emitted SQL is unchanged.
+        case_expression = "CASE %s WHEN NULL THEN NULL" % json_value
+
+        if affinity is sqltypes.Integer:
+            type_expression = "ELSE CAST(%s AS INTEGER)" % json_value
+
+        elif affinity in (sqltypes.Numeric, sqltypes.Float):
+            if isinstance(binary.type, sqltypes.Float):
+                cast_type = "FLOAT"
+            else:
+                precision = binary.type.precision
+                scale = binary.type.scale
+                # precision / scale are optional on Numeric; never render
+                # the string "None" into the type specification
+                if precision is None and scale is None:
+                    cast_type = "NUMBER"
+                elif precision is None:
+                    cast_type = "NUMBER(*, %s)" % scale
+                elif scale is None:
+                    cast_type = "NUMBER(%s)" % precision
+                else:
+                    cast_type = "NUMBER(%s, %s)" % (precision, scale)
+            type_expression = "ELSE CAST(%s AS %s)" % (json_value, cast_type)
+
+        elif affinity is sqltypes.Boolean:
+            type_expression = (
+                "WHEN 'true' THEN 1 "
+                "WHEN 'false' THEN 0 "
+                "ELSE CAST(%s AS NUMBER(1))" % json_value
+            )
+
+        elif affinity is sqltypes.String:
+            type_expression = "ELSE %s" % json_value
+
+        else:
+            # Fallback: preserve JSON structure
+            type_expression = "ELSE JSON_QUERY(%s, %s)" % (target, path)
+
+        return case_expression + " " + type_expression + " END"
+
+    def visit_json_getitem_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
+        """Compile a JSON index access via the shared extract renderer."""
+        rendered = self._render_json_extract_from_binary(
+            binary, operator, **kw
+        )
+        return rendered
+
+    def visit_json_path_getitem_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
+        """Compile a JSON path access via the shared extract renderer."""
+        rendered = self._render_json_extract_from_binary(
+            binary, operator, **kw
+        )
+        return rendered
+
 
 class OracleDDLCompiler(compiler.DDLCompiler):
 
@@ -2077,6 +2182,8 @@ class OracleDialect(default.DefaultDialect):
         use_nchar_for_unicode=False,
         exclude_tablespaces=("SYSTEM", "SYSAUX"),
         enable_offset_fetch=True,
+        json_serializer: Optional[Callable[..., Any]] = None,
+        json_deserializer: Optional[Callable[..., Any]] = None,
         **kwargs,
     ):
         default.DefaultDialect.__init__(self, **kwargs)
@@ -2087,6 +2194,8 @@ class OracleDialect(default.DefaultDialect):
         self.enable_offset_fetch = self._supports_offset_fetch = (
             enable_offset_fetch
         )
+        self._json_serializer = json_serializer
+        self._json_deserializer = json_deserializer
 
     def initialize(self, connection):
         super().initialize(connection)
diff --git a/lib/sqlalchemy/dialects/oracle/json.py b/lib/sqlalchemy/dialects/oracle/json.py
new file mode 100644 (file)
index 0000000..be01b73
--- /dev/null
@@ -0,0 +1,103 @@
+# dialects/oracle/json.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from __future__ import annotations
+from typing import Any
+from decimal import Decimal
+from typing import TYPE_CHECKING
+import json
+
+from ... import types as sqltypes
+
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+
+
+class JSON(sqltypes.JSON):
+    """Oracle JSON type.
+
+    Note: The oracledb Python driver automatically deserializes JSON column
+    data, returning native Python objects (dict, list, bool, int, float,
+    str) directly.  String results are additionally passed through the
+    dialect's JSON deserializer, falling back to ``json.loads`` when the
+    dialect has none configured.
+    """
+
+    def result_processor(self, dialect, coltype):  # type: ignore[override]
+        """Return a function converting a fetched value to a Python object.
+
+        ``None`` is passed through, ``Decimal`` values are converted to
+        ``float``, strings are handed to the JSON deserializer and returned
+        unchanged if they cannot be decoded; all other values are returned
+        as received.
+        """
+        string_process = self._str_impl.result_processor(dialect, coltype)
+        json_deserializer = (
+            getattr(dialect, "_json_deserializer", None) or json.loads
+        )
+
+        def process(value):
+            if value is None:
+                return None
+
+            if string_process:
+                value = string_process(value)
+
+            if isinstance(value, Decimal):
+                return float(value)
+
+            if not isinstance(value, str):
+                # already a native Python object; nothing to decode
+                return value
+
+            # If it's a string, it might be JSON that needs deserializing.
+            # This can happen with CAST operations or when reading from
+            # VARCHAR2 columns.
+            # NOTE(review): a string scalar whose text is itself valid JSON
+            # (e.g. "123" or "true") is decoded to a non-string value here;
+            # confirm this is acceptable for driver-decoded JSON columns.
+            try:
+                return json_deserializer(value)
+            except (ValueError, TypeError):
+                # ValueError includes json.JSONDecodeError and also the
+                # decode errors of custom deserializers that subclass
+                # ValueError directly; treat the value as plain text
+                return value
+
+        return process
+
+
+class _FormatTypeMixin:
+    """Mixin that formats a JSON index / path value before string handling.
+
+    Subclasses implement ``_format_value()``; the bind and literal
+    processors apply it first and then the corresponding string processor
+    of the type, when one exists.
+    """
+
+    def _format_value(self, value: Any) -> str:
+        raise NotImplementedError()
+
+    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+        string_proc = self.string_bind_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
+
+        def process(value: Any) -> Any:
+            formatted = self._format_value(value)
+            return string_proc(formatted) if string_proc else formatted
+
+        return process
+
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> _LiteralProcessorType[Any]:
+        string_proc = self.string_literal_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
+
+        def process(value: Any) -> str:
+            formatted = self._format_value(value)
+            if not string_proc:
+                return formatted
+            return string_proc(formatted)  # type: ignore[no-any-return]
+
+        return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+    """Formats a single JSON index as a path string.
+
+    Integers render as an array step (``$[n]``), anything else as a
+    double-quoted member step (``$."name"``).
+    """
+
+    def _format_value(self, value: Any) -> str:
+        template = "$[%s]" if isinstance(value, int) else '$."%s"'
+        return template % value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+    """Formats a sequence of JSON path elements as one path string.
+
+    The path starts with ``$``; integer elements add an array step
+    (``[n]``) and all other elements add a double-quoted member step
+    (``."name"``).
+    """
+
+    def _format_value(self, value: Any) -> str:
+        steps = []
+        for elem in value:
+            if isinstance(elem, int):
+                steps.append("[%s]" % elem)
+            else:
+                steps.append('."%s"' % elem)
+        return "$" + "".join(steps)
index d1a2a39afb7b1356d50b90e2397d710a8ca5cb27..67fcd092693958ed5cca9c38f3ca77144c17766a 100644 (file)
@@ -40,6 +40,23 @@ class NCLOB(sqltypes.Text):
 class VARCHAR2(VARCHAR):
     __visit_name__ = "VARCHAR2"
 
+    @classmethod
+    def _adapt_string_for_cast(cls, type_: sqltypes.String) -> "VARCHAR2":
+        """Adapt a String type for use in CAST expressions.
+
+        Oracle requires a length for VARCHAR2 in CAST expressions.
+        If no length is specified, we default to 4000.
+
+        :param type_: a String type class or instance.
+        :return: ``type_`` itself when it is already a VARCHAR2 with a
+         length, otherwise a new VARCHAR2 that always carries a length.
+        """
+        type_ = sqltypes.to_instance(type_)
+        if isinstance(type_, VARCHAR2) and type_.length:
+            return type_
+        elif isinstance(type_, VARCHAR):
+            # also reached for a VARCHAR2 that has no length, which would
+            # otherwise be rendered without the required length
+            return VARCHAR2(
+                length=type_.length or 4000, collation=type_.collation
+            )
+        else:
+            return VARCHAR2(length=type_.length or 4000)
+
+
 
 NVARCHAR2 = NVARCHAR
 
index b79d4b952f836b50747a3411bdcda018bab3df5c..c56eee1a69272951a7e3197ef446360d924638af 100644 (file)
@@ -1926,6 +1926,16 @@ class SuiteRequirements(Requirements):
         "indicates if the json_deserializer function is called with bytes"
         return exclusions.closed()
 
+    @property
+    def json_deserializer_is_used(self):
+        """Indicates if custom json_deserializer is called for JSON columns.
+
+        Some database drivers (e.g., Oracle's oracledb) automatically
+        deserialize JSON at the DBAPI level, returning native Python objects
+        directly, which means custom json_deserializer cannot be invoked.
+
+        Open by default so that the custom serializer round trip test keeps
+        running for dialects that do not override this requirement; a
+        dialect whose driver decodes JSON itself should close it in its own
+        requirements class.
+        """
+        return exclusions.open()
+
+
     @property
     def reflect_table_options(self):
         """Target database must support reflecting table_options."""
index bf720d9193dd88b8a50239d4a47fc81cabd38312..1ddce29f3b46fa88ff4c2843ba25ab8344c6a8ea 100644 (file)
@@ -1318,7 +1318,9 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
         Table(
             "data_table",
             metadata,
-            Column("id", Integer, primary_key=True),
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
             Column("name", String(30), nullable=False),
             Column("data", cls.datatype, nullable=False),
             Column("nulldata", cls.datatype(none_as_null=True)),
@@ -1592,6 +1594,9 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             eq_(row, (data_element, data_element))
 
     def test_round_trip_custom_json(self):
+        if not testing.requires.json_deserializer_is_used.enabled:
+            return
+
         data_table = self.tables.data_table
         data_element = {"key1": "data1"}
 
@@ -1611,6 +1616,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
             eq_(row, (data_element,))
             eq_(js.mock_calls, [mock.call(data_element)])
+
             if testing.requires.json_deserializer_binary.enabled:
                 eq_(
                     jd.mock_calls,
@@ -1936,7 +1942,9 @@ class JSONLegacyStringCastIndexTest(
         Table(
             "data_table",
             metadata,
-            Column("id", Integer, primary_key=True),
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
             Column("name", String(30), nullable=False),
             Column("data", cls.datatype),
             Column("nulldata", cls.datatype(none_as_null=True)),
index 33db8cee75472b4cc428f08dc1111ecaa4ba6ca4..b262249c31509bea25ecad3337263ef1dd714f1a 100644 (file)
@@ -1806,3 +1806,43 @@ class SetInputSizesTest(fixtures.TestBase):
             )
         finally:
             event.remove(testing.db, "do_setinputsizes", _remove_type)
+
+class JSONTest(fixtures.TestBase):
+    """Backend tests for the Oracle-specific ``JSON`` datatype."""
+
+    __requires__ = ("json_type",)
+    __only_on__ = "oracle"
+    __backend__ = True
+
+    def _create_json_table(self, metadata, connection):
+        """Create the one-column ``oracle_json`` table and return it."""
+        json_table = Table(
+            "oracle_json", metadata, Column("foo", oracle.JSON)
+        )
+        metadata.create_all(connection)
+        return json_table
+
+    @testing.requires.reflects_json_type
+    def test_reflection(self, metadata, connection):
+        self._create_json_table(metadata, connection)
+
+        reflected = Table("oracle_json", MetaData(), autoload_with=connection)
+        reflected_type = reflected.c.foo.type
+        is_(reflected_type._type_affinity, sqltypes.JSON)
+        assert isinstance(reflected_type, oracle.JSON)
+
+    def test_rudimentary_round_trip(self, metadata, connection):
+        oracle_json = self._create_json_table(metadata, connection)
+
+        value = {"json": {"foo": "bar"}, "recs": ["one", "two"]}
+
+        connection.execute(oracle_json.insert(), dict(foo=value))
+
+        fetched = connection.scalar(select(oracle_json.c.foo))
+        eq_(fetched, value)
+
+    def test_extract_subobject(self, connection, metadata):
+        oracle_json = self._create_json_table(metadata, connection)
+
+        value = {"json": {"foo": "bar"}}
+        connection.execute(oracle_json.insert(), dict(foo=value))
+
+        extracted = connection.scalar(select(oracle_json.c.foo["json"]))
+        eq_(extracted, value["json"])
index cdc5e4f869e0317e3d5ff8a5f5ecd776170f325b..796caea7f31bbae9d04cfcaf54dd7c86cc00ca42 100644 (file)
@@ -1247,8 +1247,9 @@ class DefaultRequirements(SuiteRequirements):
                 "postgresql >= 9.3",
                 self._sqlite_json,
                 "mssql",
+                "oracle>=21",
             ]
-        )
+        ) + skip_if("oracle+cx_oracle")
 
     @property
     def json_index_supplementary_unicode_element(self):
@@ -1339,6 +1340,7 @@ class DefaultRequirements(SuiteRequirements):
                 and not config.db.dialect._is_mariadb,
                 "postgresql >= 9.3",
                 "sqlite >= 3.9",
+                "oracle>=21",
             ]
         )
 
@@ -2180,6 +2182,16 @@ class DefaultRequirements(SuiteRequirements):
         "indicates if the json_deserializer function is called with bytes"
         return only_on(["postgresql+psycopg"])
 
+    @property
+    def json_deserializer_is_used(self):
+        """Indicates if custom json_deserializer is called for JSON columns.
+
+        Some database drivers (e.g., Oracle's oracledb) automatically
+        deserialize JSON at the DBAPI level, returning native Python objects
+        directly, which means custom json_deserializer cannot be invoked.
+        """
+        driver_decoded_backends = ["oracle"]
+        return skip_if(driver_decoded_backends)
+
+
     @property
     def mssql_filestream(self):
         "returns if mssql supports filestream"