]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: sqlalchemy#12317 Added vector datatype support in Oracle dialect
authorsuraj <suraj.shaw@oracle.com>
Thu, 10 Apr 2025 14:24:07 +0000 (19:54 +0530)
committersuraj <suraj.shaw@oracle.com>
Tue, 15 Apr 2025 12:27:24 +0000 (17:57 +0530)
included vector_Storage

Fixes: sqlalchemy#12317 Added vector datatype support in Oracle dialect
doc/build/dialects/oracle.rst
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/oracle/types.py
lib/sqlalchemy/dialects/oracle/vector.py [new file with mode: 0644]
test/dialect/oracle/test_types.py

index b3d44858ced703749d2da27ca470c77a5087aafb..c054a92211bd5809712ed5cac8269c6f68a2a734 100644 (file)
@@ -31,6 +31,7 @@ originate from :mod:`sqlalchemy.types` or from the local dialect::
         TIMESTAMP,
         VARCHAR,
         VARCHAR2,
+        VECTOR,
     )
 
 .. versionadded:: 1.2.19 Added :class:`_types.NCHAR` to the list of datatypes
@@ -80,6 +81,9 @@ construction arguments, are as follows:
 .. autoclass:: TIMESTAMP
   :members: __init__
 
+.. autoclass:: VECTOR
+  :members: __init__
+
 .. _oracledb:
 
 python-oracledb
index c05d8bdf872961b15bcf8ad967ee9807cd23798e..2265de033c93236c166f3e91c3dcb00770d5d7d9 100644 (file)
@@ -33,6 +33,10 @@ from .base import TIMESTAMP
 from .base import VARCHAR
 from .base import VARCHAR2
 from .base import VECTOR
+from .base import VectorIndexConfig
+from .base import VectorIndexType
+from .vector import VectorDistanceType
+from .vector import VectorStorageFormat
 
 # Alias oracledb also as oracledb_async
 oracledb_async = type(
@@ -66,4 +70,8 @@ __all__ = (
     "ROWID",
     "REAL",
     "VECTOR",
+    "VectorDistanceType",
+    "VectorIndexType",
+    "VectorIndexConfig",
+    "VectorStorageFormat",
 )
index 131a341bb432c453cf2f6113ba24284db74f543e..7b57a90e933b1bb3eda4dfb7e133f18ab9e043ea 100644 (file)
@@ -744,11 +744,140 @@ The ``oracle_compress`` parameter accepts either an integer specifying the
 number of prefix columns to compress, or ``True`` to use the default (all
 columns for non-unique indexes, all but the last column for unique indexes).
 
+VECTOR Datatype
+---------------
+
+Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine
+learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers,
+8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers.
+For more information on the VECTOR datatype please visit this `link
+<https://python-oracledb.readthedocs.io/en/latest/user_guide/vector_data_type.html>`_.
+
+CREATE TABLE
+~~~~~~~~~~~~
+
+With the :class:`.VECTOR` datatype, you can specify the dimension for the data and the storage
+format. Valid values for storage format are enum values from :class:`.VectorStorageFormat`
+(:attr:`.VectorStorageFormat.INT8`, :attr:`.VectorStorageFormat.BINARY`, :attr:`.VectorStorageFormat.FLOAT32`,
+:attr:`.VectorStorageFormat.FLOAT64`).
+To create a table that includes a :class:`.VECTOR` column::
+
+    from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat
+
+    t = Table("t1", metadata,
+        Column('id', Integer, primary_key=True),
+        Column("embedding", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+        Column(...), ...
+    )
+
+Vectors can also be defined with an arbitrary number of dimensions and formats. This allows
+you to specify vectors of different dimensions with the various storage formats mentioned above.
+
+For Example
+
+* In this case, the storage format is flexible, allowing any vector type data to be inserted,
+  such as INT8 or BINARY etc.
+
+    vector_col:Mapped[array.array] = mapped_column(VECTOR(dim=3))
+
+* The dimension is flexible in this case, meaning that any dimension vector can be used.
+
+    vector_col:Mapped[array.array] = mapped_column(VECTOR(storage_format=VectorStorageType.INT8))
+
+* Both the dimensions and the storage format are flexible.
+
+    vector_col:Mapped[array.array] = mapped_column(VECTOR)
+
+INSERT VECTOR DATA
+~~~~~~~~~~~~~~~~~~
+
+VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type
+FLOAT (32-bit), DOUBLE (64-bit), or INT (8-bit signed integer) are used as bind values when
+inserting VECTOR columns::
+
+    from sqlalchemy import insert, select
+    import array
+
+    vector_data_8 = [1, 2, 3]
+    statement  = insert(t1)
+    with engine.connect() as conn:
+        conn.execute(statement,[
+            {"id":1,"embedding":vector_data_8},
+            ])
+
+VECTOR INDEXES
+~~~~~~~~~~~~~~
+
+There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW
+index.
+
+To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use
+the default values provided by Oracle. HNSW is the default indexing method::
+
+    Index(
+            'vector_index',
+            t1.c.embedding,
+            oracle_vector = True,
+        )
+
+If you wish to use custom parameters, you can specify all the parameters using the VectorIndexConfig
+Dataclass in the `oracle_vector` option. To learn more about the parameters that can be passed please
+visit this `link. <https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/create-vector-index.html>`_
+
+    Index(
+            'hnsw_vector_index',
+            t1.c.embedding,
+            oracle_vector = VectorIndexConfig(
+                index_type = VectorIndexType.HNSW,
+                distance = VectorDistanceType.COSINE,
+                accuracy = 90,
+                hnsw_neighbors = 5,
+                hnsw_efconstruction = 20,
+                parallel = 10
+            )
+        )
+
+    Index(
+            'ivf_vector_index',
+            t1.c.embedding,
+            oracle_vector = VectorIndexConfig(
+                index_type = VectorIndexType.IVF,
+                distance = VectorDistanceType.DOT,
+                accuracy = 90,
+                ivf_neighbor_partitions = 5,
+            )
+        )
+
+Similarity Searching
+~~~~~~~~~~~~~~~~~~~~
+
+You can  use the following shorthand VECTOR distance functions:
+
+* ``l2_distance``
+* ``cosine_distance``
+* ``inner_product``
+
+Example Usage::
+
+    from sqlalchemy.orm import Session
+    from sqlalchemy.sql import func
+    import array
+
+    session = Session(bind=engine)
+    query_vector = [2,3,4]
+    result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3))
+
+    for user in vector:
+        print(user.id,user.embedding)
+
+.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database.
+
 """  # noqa
 
 from __future__ import annotations
 
 from collections import defaultdict
+from dataclasses import fields
 from functools import lru_cache
 from functools import wraps
 import re
@@ -771,7 +900,9 @@ from .types import RAW
 from .types import ROWID  # noqa
 from .types import TIMESTAMP
 from .types import VARCHAR2  # noqa
-from .types import VECTOR
+from .vector import VECTOR
+from .vector import VectorIndexConfig
+from .vector import VectorIndexType
 from ... import Computed
 from ... import exc
 from ... import schema as sa_schema
@@ -1011,13 +1142,13 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
 
     def visit_VECTOR(self, type_, **kw):
         if type_.dim is None and type_.storage_format is None:
-            return f"VECTOR"
+            return "VECTOR(*,*)"
         elif type_.storage_format is None:
             return f"VECTOR({type_.dim},*)"
         elif type_.dim is None:
-            return f"VECTOR(*, {type_.storage_format})"
+            return f"VECTOR(*,{type_.storage_format.value})"
         else:
-            return f"VECTOR({type_.dim},{type_.storage_format})"
+            return f"VECTOR({type_.dim},{type_.storage_format.value})"
 
 
 class OracleCompiler(compiler.SQLCompiler):
@@ -1505,6 +1636,48 @@ class OracleCompiler(compiler.SQLCompiler):
 
 
 class OracleDDLCompiler(compiler.DDLCompiler):
+
+    def _build_vector_index_config(
+        self, vector_index_config: VectorIndexConfig
+    ) -> str:
+        parts = []
+        sql_param_name = {
+            "hnsw_neighbors": "neighbors",
+            "hnsw_efconstruction": "efconstruction",
+            "ivf_neighbor_partitions": "neighbor partitions",
+            "ivf_sample_per_partition": "sample_per_partition",
+            "ivf_min_vectors_per_partition": "min_vectors_per_partition",
+        }
+        if vector_index_config.index_type == VectorIndexType.HNSW:
+            parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH")
+        elif vector_index_config.index_type == VectorIndexType.IVF:
+            parts.append("ORGANIZATION NEIGHBOR PARTITIONS")
+        if vector_index_config.distance is not None:
+            parts.append(f"DISTANCE {vector_index_config.distance.value}")
+
+        if vector_index_config.accuracy is not None:
+            parts.append(
+                f"WITH TARGET ACCURACY {vector_index_config.accuracy}"
+            )
+
+        parameters_str = [f"type {vector_index_config.index_type.name}"]
+        prefix = vector_index_config.index_type.name.lower() + "_"
+
+        for field in fields(vector_index_config):
+            if field.name.startswith(prefix):
+                key = sql_param_name.get(field.name)
+                value = getattr(vector_index_config, field.name)
+                if value is not None:
+                    parameters_str.append(f"{key} {value}")
+
+        parameters_str = ", ".join(parameters_str)
+        parts.append(f"PARAMETERS ({parameters_str})")
+
+        if vector_index_config.parallel is not None:
+            parts.append(f"PARALLEL {vector_index_config.parallel}")
+
+        return " ".join(parts)
+
     def define_constraint_cascades(self, constraint):
         text = ""
         if constraint.ondelete is not None:
@@ -1559,43 +1732,9 @@ class OracleDDLCompiler(compiler.DDLCompiler):
                 )
         if vector_options:
             if vector_options is True:
-                vector_options = {}
-            parts = []
-            parameters = vector_options.get("parameters", {})
-            using = parameters.get("type", "HNSW").upper()
-            if using == "HNSW":
-                parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH")
-            elif using == "IVF":
-                parts.append("ORGANIZATION NEIGHBOR PARTITIONS")
-            vector_distance = vector_options.get("distance")
-            if vector_distance is not None:
-                vector_distance = vector_distance.upper()
-                if vector_distance not in (
-                    "EUCLIDEAN",
-                    "DOT",
-                    "COSINE",
-                    "MANHATTAN",
-                ):
-                    raise ValueError("Unknown vector_distance value")
-                parts.append(f"DISTANCE {vector_distance}")
-            target_accuracy = vector_options.get("accuracy")
-            if target_accuracy is not None:
-                if target_accuracy < 0 or target_accuracy > 100:
-                    raise ValueError(
-                        "Accuracy value should be an integer between 0 and 100"
-                    )
-                parts.append(f"WITH TARGET ACCURACY {target_accuracy}")
-            if parameters:
-                parameters_str = ", ".join(
-                    f"{k} {v}" for k, v in parameters.items()
-                )
-                parts.append(f"PARAMETERS ({parameters_str})")
-            parallel = vector_options.get("parallel")
-            if parallel is not None:
-                if not isinstance(parallel, int):
-                    raise ValueError("Parallel value must be an integer")
-                parts.append(f"PARALLEL {parallel}")
-            text += " " + " ".join(parts)
+                vector_options = VectorIndexConfig()
+
+            text += " " + self._build_vector_index_config(vector_options)
         return text
 
     def post_create_table(self, table):
index 201902f997f04a9db96c474289fccd673c9d30b7..8105608837f752274ae22de030f9032856af9e61 100644 (file)
@@ -591,228 +591,6 @@ SQLAlchemy type (or a subclass of such).
 
 .. versionadded:: 2.0.0 added support for the python-oracledb driver.
 
-VECTOR Datatype
----------------
-
-Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine
-learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers,
-8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers.
-For more information on the VECTOR datatype please visit this `link.
-<https://python-oracledb.readthedocs.io/en/latest/user_guide/vector_data_type.html>`_
-
-CREATE TABLE
-^^^^^^^^^^^^
-
-With the VECTOR datatype, you can specify the dimension for the data and the storage
-format. To create a table that includes a VECTOR column::
-
-    from sqlalchemy.dialects.oracle import VECTOR
-
-    t = Table("t1", metadata,
-        Column('id', Integer, primary_key=True),
-        Column("embedding", VECTOR(3, 'float32'),
-        Column(...), ...
-    )
-
-Vectors can also be defined with an arbitrary number of dimensions and formats. This allows
-you to specify vectors of different dimensions with the various storage formats mentioned above.
-
-For Example
-
-* In this case, the storage format is flexible, allowing any vector type data to be inserted,
-  such as int8 or binary etc.
-    
-    vector_col:Mapped[array.array] = mapped_column(VECTOR(3))
-
-* The dimension is flexible in this case, meaning that any dimension vector can be used.
-   
-    vector_col:Mapped[array.array] = mapped_column(VECTOR('int8'))
-
-* Both the dimensions and the storage format are flexible.
-
-    vector_col:Mapped[array.array] = mapped_column(VECTOR)
-
-INSERT VECTOR DATA
-^^^^^^^^^^^^^^^^^^
-
-VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type
-float (32-bit), double (64-bit), or int8_t (8-bit signed integer) are used as bind values when
-inserting VECTOR columns::
-
-    from sqlalchemy import insert, select
-    import array
-
-    vector_data_8 = [1, 2, 3]
-    statement  = insert(t1)
-    with engine.connect() as conn:
-        conn.execute(statement,[
-            {"id":1,"embedding":vector_data_8},
-            ])
-    
-VECTOR INDEXES
-^^^^^^^^^^^^^^
-
-There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW
-index.
-
-To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use
-the default values provided by Oracle. HNSW is the default indexing method::
-
-    Index(
-            'vector_index',
-            t1.c.embedding,
-            oracle_vector = True,
-        )
-
-If you wish to use custom parameters, you can specify all the parameters as a dictionary
-in the `oracle_vector` option. To learn more about the parameters that can be passed please
-visit this `link. <https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/create-vector-index.html>`_
-
-Configuring Oracle VECTOR Indexes
-=================================
-
-When using Oracle VECTOR indexes, the configuration parameters are divided into two levels:
-**top-level keys** and **nested keys**. This structure applies to both HNSW and IVF VECTOR indexes.
-
-Top-Level Keys
-==============
-
-These keys are specified directly under the ``oracle_vector`` dictionary.
-
-* ``accuracy``:
-    - Specifies the accuracy of the nearest neighbor search during query execution.
-    - **Valid Range**: Greater than 0 and less than or equal to 100.
-    - **Example**: ``'accuracy': 95``
-
-* ``distance``:
-    - Specifies the metric for calculating distance between VECTORS.
-    - **Valid Values**: ``"EUCLIDEAN"``, ``"COSINE"``, ``"DOT"``, ``"MANHATTAN"``.
-    - **Example**: ``'distance': "COSINE"``
-
-* ``parameters``:
-    - A nested dictionary where method-specific options are defined (e.g., HNSW or IVF-specific settings).
-    - **Example**: ``'parameters': {...}``
-
-Nested keys in parameters
-=========================
-
-These keys are specific to the indexing method and are included under the ``parameters`` dictionary.
-
-HNSW Parameters
-===============
-
-* ``type``:
-    - Specifies the indexing method. For HNSW, this must be ``"HNSW"``.
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'type': 'HNSW'``
-
-* ``neighbors``:
-    - The number of nearest neighbors considered during the search.
-    - **Valid Range**: Greater than 0 and less than or equal to 2048.
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'neighbors': 20``
-
-* ``efconstruction``:
-    - Controls the trade-off between indexing speed and recall quality during index construction.
-    - **Valid Range**: Greater than 0 and less than or equal to 65535.
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'efconstruction': 300``
-
-IVF Parameters
-==============
-
-* ``type``:
-    - Specifies the indexing method. For IVF, this must be ``"IVF"``.
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'type': 'IVF'``
-   
-* ``neighbor partitions``:
-    - The number of partitions used to divide the dataset.
-    - **Valid Range**: Greater than 0 and less than or equal to 10,000,000.
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'neighbor partitions': 10``
-
-* ``sample_per_partition``:
-    - The number of samples used per partition.
-    - **Valid Range**: Between 1 and ``num_vectors / neighbor_partitions``.
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'sample_per_partition': 5``
-
-* ``min_vectors_per_partition``:
-    - The minimum number of vectors per partition.
-    - **Valid Range**: From 0 (no trimming) to the total number of vectors (results in 1 partition).
-    - **Placement**: Nested under ``parameters``.
-    - **Example**: ``'min_vectors_per_partition': 100``
-
-Example Configurations
-======================
-
-For custom configurations, the parameters can be specified as shown in the following examples::
-
-    Index(
-            'hnsw_vector_index',
-            t1.c.embedding,
-            oracle_vector = {
-                'accuracy':95,
-                'distance':"COSINE",
-                'parameters':{
-                    'type':'HNSW',
-                    'neighbors':20',
-                    efconstruction':300
-                }
-            }
-        )
-
-    Index(
-            'ivf_vector_index',
-            t1.c.embedding,
-            oracle_vector = {
-                'accuracy':90,
-                'distance':"DOT",
-                'parameters':{
-                    'type':'IVF','
-                    neighbor partitions':10
-                }
-            }
-        )
-
-Similarity Searching
-^^^^^^^^^^^^^^^^^^^^
-
-You can  use the following shorthand VECTOR distance functions:
-
-* ``l2_distance``
-* ``cosine_distance``
-* ``inner_product``
-
-Example Usage::
-
-    from sqlalchemy.orm import Session
-    from sqlalchemy.sql import func
-    import array
-  
-    session = Session(bind=engine)
-    query_vector = [2,3,4]
-    result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3))
-    for user in vector:
-        print(user.id,user.embedding)
-
-Exact and Approximate Searching
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Similarity searches tend to get data from one or more clusters depending on the value of the query VECTOR and the fetch
-size. Approximate searches using VECTOR indexes can limit the searches to specific clusters, whereas exact searches visit
-VECTORS across all clusters.
-You can use the fetch_type clause to set the searching to be either EXACT, APPROX or APPROXIMATE::
-
-    result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("EXACT")
-    result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("APPROX")
-
-
-.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database.
-:class:`_oracle.VECTOR` datatype.
-
 """  # noqa
 from __future__ import annotations
 
index f128a42352c555b8c418ca0ffeef844f5c9e673e..06aeaace2f5fc7dba4f14cbab4d13822064f7701 100644 (file)
@@ -11,14 +11,11 @@ import datetime as dt
 from typing import Optional
 from typing import Type
 from typing import TYPE_CHECKING
-import sqlalchemy.types as types
-from sqlalchemy.types import UserDefinedType, Float
 
 from ... import exc
 from ...sql import sqltypes
 from ...types import NVARCHAR
 from ...types import VARCHAR
-import array
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
@@ -317,82 +314,3 @@ class ROWID(sqltypes.TypeEngine):
 class _OracleBoolean(sqltypes.Boolean):
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
-
-
-class VECTOR(types.TypeEngine):
-    """Oracle VECTOR datatype."""
-
-    cache_ok = True
-    __visit_name__ = "VECTOR"
-
-    def __init__(self, dim=None, storage_format=None, *args):
-        """
-        :param dim: The dimension of the VECTOR datatype. This should be an
-        integer value.
-        :param storage_format: The VECTOR storage type format. This
-        may be int8, binary, float32, or float64.
-        """
-        if dim is not None and isinstance(dim, int):
-            self.dim = dim
-            self.storage_format = storage_format
-
-        elif dim is not None and isinstance(dim, str):
-            self.dim = storage_format
-            self.storage_format = dim
-
-        else:
-            self.dim = storage_format
-            self.storage_format = dim
-
-    def _cached_bind_processor(self, dialect):
-        """
-        Convert a list to a array.array before binding it to the database.
-        """
-
-        def process(value):
-            if value is None or isinstance(value, array.array):
-                return value
-
-            # Convert list to a array.array
-            elif isinstance(value, list):
-                format = self._array_typecode(self.storage_format)
-                value = array.array(format, value)
-                return value
-
-            else:
-                raise TypeError("VECTOR accepts list or array.array()")
-
-        return process
-
-    def _cached_result_processor(self, dialect, coltype):
-        """
-        Convert a array.array to list before binding it to the database.
-        """
-
-        def process(value):
-            if isinstance(value, array.array):
-                return list(value)
-
-        return process
-
-    def _array_typecode(self, format):
-        """
-        Map storage format to array typecode.
-        """
-        typecode_map = {
-            "int8": "b",  # Signed int
-            "binary": "B",  # Unsigned int
-            "float32": "f",  # Float
-            "float64": "d",  # Double
-        }
-        return typecode_map.get(format, "d")
-
-    class comparator_factory(types.TypeEngine.Comparator):
-        def l2_distance(self, other):
-            return self.op("<->", return_type=Float)(other)
-
-        def inner_product(self, other):
-            return self.op("<#>", return_type=Float)(other)
-
-        def cosine_distance(self, other):
-            return self.op("<=>", return_type=Float)(other)
diff --git a/lib/sqlalchemy/dialects/oracle/vector.py b/lib/sqlalchemy/dialects/oracle/vector.py
new file mode 100644 (file)
index 0000000..84b64cd
--- /dev/null
@@ -0,0 +1,243 @@
+# dialects/oracle/vector.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import array
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+import sqlalchemy.types as types
+from sqlalchemy.types import Float
+
+
+class VectorIndexType(Enum):
+    """Enum representing different types of VECTOR index structures.
+
+    Each index type has different characteristics and configuration
+    parameters optimized for specific use cases in vector search.
+
+    """
+
+    HNSW = "HNSW"
+    """
+    The HNSW (Hierarchical Navigable Small World) index type.
+    """
+    IVF = "IVF"
+    """
+    The IVF (Inverted File Index) index type
+    """
+
+
+class VectorDistanceType(Enum):
+    """Enum representing different types of vector distance metrics."""
+
+    EUCLIDEAN = "EUCLIDEAN"
+    """Euclidean distance (L2 norm).
+
+    Measures the straight-line distance between two vectors in space.
+    """
+    DOT = "DOT"
+    """Dot product similarity.
+
+    Measures the algebraic similarity between two vectors.
+    """
+    COSINE = "COSINE"
+    """Cosine similarity.
+
+    Measures the cosine of the angle between two vectors.
+    """
+    MANHATTAN = "MANHATTAN"
+    """Manhattan distance (L1 norm).
+
+    Calculates the sum of absolute differences across dimensions.
+    """
+
+
+class VectorStorageFormat(Enum):
+    """Enum representing the data format used to store vector components.
+
+    Choosing the right format balances precision, memory usage,
+    and performance for vector indexing and search.
+    """
+
+    INT8 = "INT8"
+    """
+    8-bit integer format.
+    """
+    BINARY = "BINARY"
+    """
+    Binary format.
+    """
+    FLOAT32 = "FLOAT32"
+    """
+    32-bit floating-point format.
+    """
+    FLOAT64 = "FLOAT64"
+    """
+    64-bit floating-point format.
+    """
+
+
+@dataclass
+class VectorIndexConfig:
+    """Define the configuration for Oracle VECTOR Index.
+
+    :param index_type: Enum values from :class:`.VectorIndexType
+     Specifies the indexing method. For HNSW, this must be
+     :attr:`.VectorIndexType.HNSW`.
+
+    :param distance: Enum values from :class:`.VectorDistanceType`
+     specifies the metric for calculating distance between VECTORS.
+
+    :param accuracy: interger. Should be in the range 0 to 100
+     Specifies the accuracy of the nearest neighbor search during
+     query execution.
+
+    :param parallel: integer. Specifies degree of parallelism.
+
+    :param hnsw_neighbors: interger. Should be in the range 0 to
+     2048. Specifies the number of nearest neighbors considered
+     during the search The attribute :attr:`VectorIndexConfig.hnsw_neighbors`
+     is HNSW index specific.
+
+    :param hnsw_efconstruction: integer. Should be in the range 0
+     to 65535.  Controls the trade-off between indexing speed and
+     recall quality during index construction.The attribute
+     :attr:`VectorIndexConfig.hnsw_efconstruction` is HNSW index
+     specific.
+
+    :param ivf_neighbour_partitions: integer. Should be in the range
+     0 to 10,000,000. Specifies the number of partitions used to
+     divide the dataset. The attribute
+     :attr:`VectorIndexConfig.ivf_neighbour_partitions` is IVF index
+     specific.
+
+    :param ivf sample_per_partition: integer. Should be between 1
+     and ``num_vectors / neighbor partitions``. Specifies the
+     number of samples used per partition. The attribute
+     :attr:`VectorIndexConfig.ivf sample_per_partition` is IVF index
+     specific.
+
+    :param ivf_min_vectors_per_partition: integer. From 0 (no trimming)
+     to the total number of vectors (results in 1 partition). Specifies
+     the minimum number of vectors per partition.
+
+    """
+
+    index_type: VectorIndexType = VectorIndexType.HNSW
+    distance: Optional[VectorDistanceType] = None
+    accuracy: Optional[int] = None
+    hnsw_neighbors: Optional[int] = None
+    hnsw_efconstruction: Optional[int] = None
+    ivf_neighbor_partitions: Optional[int] = None
+    ivf_sample_per_partition: Optional[int] = None
+    ivf_min_vectors_per_partition: Optional[int] = None
+    parallel: Optional[int] = None
+
+    def __post_init__(self):
+        self.index_type = VectorIndexType(self.index_type)
+        for field in [
+            "hnsw_neighbors",
+            "hnsw_efconstruction",
+            "ivf_neighbor_partitions",
+            "ivf_sample_per_partition",
+            "ivf_min_vectors_per_partition",
+            "parallel",
+            "accuracy",
+        ]:
+            value = getattr(self, field)
+            if value is not None and not isinstance(value, int):
+                raise TypeError(
+                    f"{field} must be an integer if"
+                    f"provided, got {type(value).__name__}"
+                )
+
+
+class VECTOR(types.TypeEngine):
+    """Oracle VECTOR datatype."""
+
+    cache_ok = True
+    __visit_name__ = "VECTOR"
+
+    def __init__(self, dim=None, storage_format=None):
+        """Construct a VECTOR.
+
+        :param dim: integer. The dimension of the VECTOR datatype. This
+         should be an integer value.
+
+        :param storage_format: VectorStorageFormat. The VECTOR storage
+         type format. This may be Enum values form
+         `VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64.
+
+        """
+        if dim is not None and not isinstance(dim, int):
+            raise TypeError("dim must be an interger")
+        if storage_format is not None and not isinstance(
+            storage_format, VectorStorageFormat
+        ):
+            raise TypeError(
+                "storage_format must be an enum of type VectorStorageFormat"
+            )
+        self.dim = dim
+        self.storage_format = storage_format
+
+    def _cached_bind_processor(self, dialect):
+        """
+        Convert a list to a array.array before binding it to the database.
+        """
+
+        def process(value):
+            if value is None or isinstance(value, array.array):
+                return value
+
+            # Convert list to a array.array
+            elif isinstance(value, list):
+                typecode = self._array_typecode(self.storage_format)
+                value = array.array(typecode, value)
+                return value
+
+            else:
+                raise TypeError("VECTOR accepts list or array.array()")
+
+        return process
+
+    def _cached_result_processor(self, dialect, coltype):
+        """
+        Convert a array.array to list before binding it to the database.
+        """
+
+        def process(value):
+            if isinstance(value, array.array):
+                return list(value)
+
+        return process
+
+    def _array_typecode(self, typecode):
+        """
+        Map storage format to array typecode.
+        """
+        typecode_map = {
+            VectorStorageFormat.INT8: "b",  # Signed int
+            VectorStorageFormat.BINARY: "B",  # Unsigned int
+            VectorStorageFormat.FLOAT32: "f",  # Float
+            VectorStorageFormat.FLOAT64: "d",  # Double
+        }
+        return typecode_map.get(typecode, "d")
+
+    class comparator_factory(types.TypeEngine.Comparator):
+        def l2_distance(self, other):
+            return self.op("<->", return_type=Float)(other)
+
+        def inner_product(self, other):
+            return self.op("<#>", return_type=Float)(other)
+
+        def cosine_distance(self, other):
+            return self.op("<=>", return_type=Float)(other)
index cf22f20f976c5d063f758d9c323ee2b60dc74ba4..e03a8e3a6c8c75f25f41eb8dbf3882a026460966 100644 (file)
@@ -1,8 +1,8 @@
+import array
 import datetime
 import decimal
 import os
 import random
-import array
 
 from sqlalchemy import bindparam
 from sqlalchemy import cast
@@ -16,6 +16,7 @@ from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import FLOAT
 from sqlalchemy import Float
+from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import LargeBinary
 from sqlalchemy import literal
@@ -35,12 +36,14 @@ from sqlalchemy import types as sqltypes
 from sqlalchemy import Unicode
 from sqlalchemy import UnicodeText
 from sqlalchemy import VARCHAR
-from sqlalchemy import Index
-from sqlalchemy import delete
 from sqlalchemy.dialects.oracle import base as oracle
 from sqlalchemy.dialects.oracle import cx_oracle
 from sqlalchemy.dialects.oracle import oracledb
 from sqlalchemy.dialects.oracle import VECTOR
+from sqlalchemy.dialects.oracle import VectorDistanceType
+from sqlalchemy.dialects.oracle import VectorIndexConfig
+from sqlalchemy.dialects.oracle import VectorIndexType
+from sqlalchemy.dialects.oracle import VectorStorageFormat
 from sqlalchemy.sql import column
 from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -955,44 +958,45 @@ class TypesTest(fixtures.TestBase):
         finally:
             exec_sql(connection, "DROP TABLE Z_TEST")
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_dim(self, metadata, connection):
-        t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32")))
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)
+            ),
+        )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
-            eq_(t1.c.c1.type.dim, 3)
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        t1.create(connection)
+        eq_(t1.c.c1.type.dim, 3)
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_insert(self, metadata, connection):
         t1 = Table(
             "t1",
             metadata,
             Column("id", Integer, primary_key=True),
-            Column("c1", VECTOR("int8")),
+            Column("c1", VECTOR(storage_format=VectorStorageFormat.INT8)),
         )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
-            connection.execute(
-                t1.insert(),
-                dict(id=1, c1=[6, 7, 8, 5]),
-            )
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6, 7, 8, 5]),
-            )
-            connection.execute(t1.delete().where(t1.c.id == 1))
-            connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6, 7]),
-            )
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
+            dict(id=1, c1=[6, 7, 8, 5]),
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7, 8, 5]),
+        )
+        connection.execute(t1.delete().where(t1.c.id == 1))
+        connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7]),
+        )
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_insert_array(self, metadata, connection):
         t1 = Table(
             "t1",
@@ -1001,30 +1005,27 @@ class TypesTest(fixtures.TestBase):
             Column("c1", VECTOR),
         )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
-            connection.execute(
-                t1.insert(),
-                dict(id=1, c1=array.array("b", [6, 7, 8, 5])),
-            )
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6, 7, 8, 5]),
-            )
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
+            dict(id=1, c1=array.array("b", [6, 7, 8, 5])),
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7, 8, 5]),
+        )
 
-            connection.execute(t1.delete().where(t1.c.id == 1))
+        connection.execute(t1.delete().where(t1.c.id == 1))
 
-            connection.execute(
-                t1.insert(), dict(id=1, c1=array.array("b", [6, 7]))
-            )
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6, 7]),
-            )
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        connection.execute(
+            t1.insert(), dict(id=1, c1=array.array("b", [6, 7]))
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7]),
+        )
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_multiformat_insert(self, metadata, connection):
         t1 = Table(
             "t1",
@@ -1033,119 +1034,117 @@ class TypesTest(fixtures.TestBase):
             Column("c1", VECTOR),
         )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
-            connection.execute(
-                t1.insert(),
-                dict(id=1, c1=[6.12, 7.54, 8.33]),
-            )
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6.12, 7.54, 8.33]),
-            )
-            connection.execute(t1.delete().where(t1.c.id == 1))
-            connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6, 7]),
-            )
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
+            dict(id=1, c1=[6.12, 7.54, 8.33]),
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6.12, 7.54, 8.33]),
+        )
+        connection.execute(t1.delete().where(t1.c.id == 1))
+        connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7]),
+        )
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_format(self, metadata, connection):
-        t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32")))
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)
+            ),
+        )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
-            eq_(t1.c.c1.type.storage_format, "float32")
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        t1.create(connection)
+        eq_(t1.c.c1.type.storage_format, VectorStorageFormat.FLOAT32)
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_hnsw_index(self, metadata, connection):
         t1 = Table(
             "t1",
             metadata,
             Column("id", Integer),
-            Column("embedding", VECTOR(3, "float32")),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+            ),
         )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
+        t1.create(connection)
 
-            hnsw_index = Index(
-                "hnsw_vector_index", t1.c.embedding, oracle_vector=True
-            )
-            hnsw_index.create(connection)
+        hnsw_index = Index(
+            "hnsw_vector_index", t1.c.embedding, oracle_vector=True
+        )
+        hnsw_index.create(connection)
 
-            connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6.0, 7.0, 8.0]),
-            )
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6.0, 7.0, 8.0]),
+        )
 
+    @testing.fails_if("oracle<=23.4")
     def test_vector_ivf_index(self, metadata, connection):
         t1 = Table(
             "t1",
             metadata,
             Column("id", Integer),
-            Column("embedding", VECTOR(3, "float32")),
-        )
-
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
-            ivf_index = Index(
-                "ivf_vector_index",
-                t1.c.embedding,
-                oracle_vector={
-                    "accuracy": 90,
-                    "distance": "DOT",
-                    "parameters": {"type": "IVF", "neighbor partitions": 10},
-                },
-            )
-            ivf_index.create(connection)
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+            ),
+        )
 
-            connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
-            eq_(
-                connection.execute(t1.select()).first(),
-                (1, [6.0, 7.0, 8.0]),
-            )
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        t1.create(connection)
+        ivf_index = Index(
+            "ivf_vector_index",
+            t1.c.embedding,
+            oracle_vector=VectorIndexConfig(
+                index_type=VectorIndexType.IVF,
+                distance=VectorDistanceType.DOT,
+                accuracy=90,
+                ivf_neighbor_partitions=5,
+            ),
+        )
+        ivf_index.create(connection)
 
+        connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6.0, 7.0, 8.0]),
+        )
+
+    @testing.fails_if("oracle<=23.4")
     def test_vector_l2_distance(self, metadata, connection):
         t1 = Table(
             "t1",
             metadata,
             Column("id", Integer),
-            Column("embedding", VECTOR(3, "int8")),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.INT8),
+            ),
         )
 
-        if testing.against("oracle>23.4"):
-            t1.create(connection)
+        t1.create(connection)
 
-            connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10]))
-            connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3]))
-            connection.execute(
-                t1.insert(),
-                dict(id=3, embedding=[15, 16, 17]),
-            )
+        connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10]))
+        connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3]))
+        connection.execute(
+            t1.insert(),
+            dict(id=3, embedding=[15, 16, 17]),
+        )
 
-            query_vector = [2, 3, 4]
-            res = connection.execute(
-                t1.select().order_by(
-                    (t1.c.embedding.l2_distance(query_vector))
-                )
-            ).first()
-            eq_(res.embedding, [1, 2, 3])
-        else:
-            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
-                t1.create(connection)
+        query_vector = [2, 3, 4]
+        res = connection.execute(
+            t1.select().order_by((t1.c.embedding.l2_distance(query_vector)))
+        ).first()
+        eq_(res.embedding, [1, 2, 3])
 
 
 class LOBFetchTest(fixtures.TablesTest):