]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added vector datatype support in Oracle dialect
authorsuraj <suraj.shaw@oracle.com>
Mon, 5 May 2025 15:14:35 +0000 (11:14 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 May 2025 15:43:30 +0000 (11:43 -0400)
Added new datatype :class:`_oracle.VECTOR` and accompanying DDL and DQL
support to fully support this type for Oracle Database. This change
includes the base :class:`_oracle.VECTOR` type that adds new type-specific
methods ``l2_distance``, ``cosine_distance``, ``inner_product`` as well as
new parameters ``oracle_vector`` for the :class:`.Index` construct,
allowing vector indexes to be configured, and ``oracle_fetch_approximate``
for the :meth:`.Select.fetch` clause.  Pull request courtesy Suraj Shaw.

Fixes: #12317
Closes: #12321
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12321
Pull-request-sha: a72a18a45c85ae7fa50a34e97ac642e16b463b54

Change-Id: I6f3af4623ce439d0820c14582cd129df293f0ba8

doc/build/changelog/unreleased_20/12317.rst [new file with mode: 0644]
doc/build/dialects/oracle.rst
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/vector.py [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/dialect/oracle/test_compiler.py
test/dialect/oracle/test_reflection.py
test/dialect/oracle/test_types.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_20/12317.rst b/doc/build/changelog/unreleased_20/12317.rst
new file mode 100644 (file)
index 0000000..13f6969
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: usecase, oracle
+    :tickets: 12317, 12341
+
+    Added new datatype :class:`_oracle.VECTOR` and accompanying DDL and DQL
+    support to fully support this type for Oracle Database. This change
+    includes the base :class:`_oracle.VECTOR` type that adds new type-specific
+    methods ``l2_distance``, ``cosine_distance``, ``inner_product`` as well as
+    new parameters ``oracle_vector`` for the :class:`.Index` construct,
+    allowing vector indexes to be configured, and ``oracle_fetch_approximate``
+    for the :meth:`.Select.fetch` clause.  Pull request courtesy Suraj Shaw.
+
+    .. seealso::
+
+        :ref:`oracle_vector_datatype`
+
index 757cc03ed205dbb092fed4d794ba421ebc3c8d01..b9e9a1d087030f5c229f81b6c997f3cc74f1b7e8 100644 (file)
@@ -31,6 +31,7 @@ originate from :mod:`sqlalchemy.types` or from the local dialect::
         TIMESTAMP,
         VARCHAR,
         VARCHAR2,
+        VECTOR,
     )
 
 Types which are specific to Oracle Database, or have Oracle-specific
@@ -77,6 +78,23 @@ construction arguments, are as follows:
 .. autoclass:: TIMESTAMP
   :members: __init__
 
+.. autoclass:: VECTOR
+  :members: __init__
+
+.. autoclass:: VectorIndexType
+  :members:
+
+.. autoclass:: VectorIndexConfig
+  :members:
+  :undoc-members:
+
+.. autoclass:: VectorStorageFormat
+  :members:
+
+.. autoclass:: VectorDistanceType
+  :members:
+
+
 .. _oracledb:
 
 python-oracledb
index 7ceb743d616ec31b95e0012fe1ffb53dda801e1a..2265de033c93236c166f3e91c3dcb00770d5d7d9 100644 (file)
@@ -32,6 +32,11 @@ from .base import ROWID
 from .base import TIMESTAMP
 from .base import VARCHAR
 from .base import VARCHAR2
+from .base import VECTOR
+from .base import VectorIndexConfig
+from .base import VectorIndexType
+from .vector import VectorDistanceType
+from .vector import VectorStorageFormat
 
 # Alias oracledb also as oracledb_async
 oracledb_async = type(
@@ -64,4 +69,9 @@ __all__ = (
     "NVARCHAR2",
     "ROWID",
     "REAL",
+    "VECTOR",
+    "VectorDistanceType",
+    "VectorIndexType",
+    "VectorIndexConfig",
+    "VectorStorageFormat",
 )
index c32dff2ea10691be49541aeb78216ef1b9d52645..f24f4f54b0db0ae4cfcadb9eee3e902761f04f24 100644 (file)
@@ -730,11 +730,177 @@ The ``oracle_compress`` parameter accepts either an integer specifying the
 number of prefix columns to compress, or ``True`` to use the default (all
 columns for non-unique indexes, all but the last column for unique indexes).
 
+.. _oracle_vector_datatype:
+
+VECTOR Datatype
+---------------
+
+Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence
+and machine learning search operations. The VECTOR datatype is a homogeneous array
+of 8-bit signed integers, 8-bit unsigned integers (binary), 32-bit floating-point numbers,
+or 64-bit floating-point numbers.
+
+.. seealso::
+
+    `Using VECTOR Data
+    <https://python-oracledb.readthedocs.io/en/latest/user_guide/vector_data_type.html>`_ - in the documentation
+    for the :ref:`oracledb` driver.
+
+.. versionadded:: 2.0.41
+
+CREATE TABLE support for VECTOR
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+With the :class:`.VECTOR` datatype, you can specify the dimension for the data
+and the storage format. Valid values for storage format are enum values from
+:class:`.VectorStorageFormat`. To create a table that includes a
+:class:`.VECTOR` column::
+
+    from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat
+
+    t = Table(
+        "t1",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column(
+            "embedding",
+            VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+        ),
+        Column(...),
+        ...,
+    )
+
+Vectors can also be defined with an arbitrary number of dimensions and formats.
+This allows you to specify vectors of different dimensions with the various
+storage formats mentioned above.
+
+**Examples**
+
+* In this case, the storage format is flexible, allowing any vector type data to be inserted,
+  such as INT8 or BINARY etc::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR(dim=3))
+
+* The dimension is flexible in this case, meaning that any dimension vector can be used::
+
+    vector_col: Mapped[array.array] = mapped_column(
+        VECTOR(storage_format=VectorStorageType.INT8)
+    )
+
+* Both the dimensions and the storage format are flexible::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR)
+
+Python Datatypes for VECTOR
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+VECTOR data can be inserted using Python list or Python ``array.array()`` objects.
+Python arrays of type FLOAT (32-bit), DOUBLE (64-bit), or INT (8-bit signed integer)
+are used as bind values when inserting VECTOR columns::
+
+    from sqlalchemy import insert, select
+
+    with engine.begin() as conn:
+        conn.execute(
+            insert(t1),
+            {"id": 1, "embedding": [1, 2, 3]},
+        )
+
+VECTOR Indexes
+~~~~~~~~~~~~~~
+
+The VECTOR feature supports an Oracle-specific parameter ``oracle_vector``
+on the :class:`.Index` construct, which allows the construction of VECTOR
+indexes.
+
+To utilize VECTOR indexing, set the ``oracle_vector`` parameter to True to use
+the default values provided by Oracle. HNSW is the default indexing method::
+
+    from sqlalchemy import Index
+
+    Index(
+        "vector_index",
+        t1.c.embedding,
+        oracle_vector=True,
+    )
+
+The full range of parameters for vector indexes are available by using the
+:class:`.VectorIndexConfig` dataclass in place of a boolean; this dataclass
+allows full configuration of the index::
+
+    Index(
+        "hnsw_vector_index",
+        t1.c.embedding,
+        oracle_vector=VectorIndexConfig(
+            index_type=VectorIndexType.HNSW,
+            distance=VectorDistanceType.COSINE,
+            accuracy=90,
+            hnsw_neighbors=5,
+            hnsw_efconstruction=20,
+            parallel=10,
+        ),
+    )
+
+    Index(
+        "ivf_vector_index",
+        t1.c.embedding,
+        oracle_vector=VectorIndexConfig(
+            index_type=VectorIndexType.IVF,
+            distance=VectorDistanceType.DOT,
+            accuracy=90,
+            ivf_neighbor_partitions=5,
+        ),
+    )
+
+For complete explanation of these parameters, see the Oracle documentation linked
+below.
+
+.. seealso::
+
+    `CREATE VECTOR INDEX <https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B396C369-54BB-4098-A0DD-7C54B3A0D66F>`_ - in the Oracle documentation
+
+
+
+Similarity Searching
+~~~~~~~~~~~~~~~~~~~~
+
+When using the :class:`_oracle.VECTOR` datatype with a :class:`.Column` or similar
+ORM mapped construct, additional comparison functions are available, including:
+
+* ``l2_distance``
+* ``cosine_distance``
+* ``inner_product``
+
+Example Usage::
+
+    result_vector = connection.scalars(
+        select(t1).order_by(t1.embedding.l2_distance([2, 3, 4])).limit(3)
+    )
+
+    for user in vector:
+        print(user.id, user.embedding)
+
+FETCH APPROXIMATE support
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Approximate vector search can only be performed when all syntax and semantic
+rules are satisfied, the corresponding vector index is available, and the
+query optimizer determines to perform it. If any of these conditions are
+unmet, then an approximate search is not performed. In this case the query
+returns exact results.
+
+To enable approximate searching during similarity searches on VECTORS, the
+``oracle_fetch_approximate`` parameter may be used with the :meth:`.Select.fetch`
+clause to add ``FETCH APPROX`` to the SELECT statement::
+
+    select(users_table).fetch(5, oracle_fetch_approximate=True)
+
 """  # noqa
 
 from __future__ import annotations
 
 from collections import defaultdict
+from dataclasses import fields
 from functools import lru_cache
 from functools import wraps
 import re
@@ -757,6 +923,9 @@ from .types import RAW
 from .types import ROWID  # noqa
 from .types import TIMESTAMP
 from .types import VARCHAR2  # noqa
+from .vector import VECTOR
+from .vector import VectorIndexConfig
+from .vector import VectorIndexType
 from ... import Computed
 from ... import exc
 from ... import schema as sa_schema
@@ -775,6 +944,7 @@ from ...sql import func
 from ...sql import null
 from ...sql import or_
 from ...sql import select
+from ...sql import selectable as sa_selectable
 from ...sql import sqltypes
 from ...sql import util as sql_util
 from ...sql import visitors
@@ -836,6 +1006,7 @@ ischema_names = {
     "BINARY_DOUBLE": BINARY_DOUBLE,
     "BINARY_FLOAT": BINARY_FLOAT,
     "ROWID": ROWID,
+    "VECTOR": VECTOR,
 }
 
 
@@ -993,6 +1164,16 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
     def visit_ROWID(self, type_, **kw):
         return "ROWID"
 
+    def visit_VECTOR(self, type_, **kw):
+        if type_.dim is None and type_.storage_format is None:
+            return "VECTOR(*,*)"
+        elif type_.storage_format is None:
+            return f"VECTOR({type_.dim},*)"
+        elif type_.dim is None:
+            return f"VECTOR(*,{type_.storage_format.value})"
+        else:
+            return f"VECTOR({type_.dim},{type_.storage_format.value})"
+
 
 class OracleCompiler(compiler.SQLCompiler):
     """Oracle compiler modifies the lexical structure of Select
@@ -1234,6 +1415,29 @@ class OracleCompiler(compiler.SQLCompiler):
         else:
             return select._fetch_clause
 
+    def fetch_clause(
+        self,
+        select,
+        fetch_clause=None,
+        require_offset=False,
+        use_literal_execute_for_simple_int=False,
+        **kw,
+    ):
+        text = super().fetch_clause(
+            select,
+            fetch_clause=fetch_clause,
+            require_offset=require_offset,
+            use_literal_execute_for_simple_int=(
+                use_literal_execute_for_simple_int
+            ),
+            **kw,
+        )
+
+        if select.dialect_options["oracle"]["fetch_approximate"]:
+            text = re.sub("FETCH FIRST", "FETCH APPROX FIRST", text)
+
+        return text
+
     def translate_select_structure(self, select_stmt, **kwargs):
         select = select_stmt
 
@@ -1482,6 +1686,48 @@ class OracleCompiler(compiler.SQLCompiler):
 
 
 class OracleDDLCompiler(compiler.DDLCompiler):
+
+    def _build_vector_index_config(
+        self, vector_index_config: VectorIndexConfig
+    ) -> str:
+        parts = []
+        sql_param_name = {
+            "hnsw_neighbors": "neighbors",
+            "hnsw_efconstruction": "efconstruction",
+            "ivf_neighbor_partitions": "neighbor partitions",
+            "ivf_sample_per_partition": "sample_per_partition",
+            "ivf_min_vectors_per_partition": "min_vectors_per_partition",
+        }
+        if vector_index_config.index_type == VectorIndexType.HNSW:
+            parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH")
+        elif vector_index_config.index_type == VectorIndexType.IVF:
+            parts.append("ORGANIZATION NEIGHBOR PARTITIONS")
+        if vector_index_config.distance is not None:
+            parts.append(f"DISTANCE {vector_index_config.distance.value}")
+
+        if vector_index_config.accuracy is not None:
+            parts.append(
+                f"WITH TARGET ACCURACY {vector_index_config.accuracy}"
+            )
+
+        parameters_str = [f"type {vector_index_config.index_type.name}"]
+        prefix = vector_index_config.index_type.name.lower() + "_"
+
+        for field in fields(vector_index_config):
+            if field.name.startswith(prefix):
+                key = sql_param_name.get(field.name)
+                value = getattr(vector_index_config, field.name)
+                if value is not None:
+                    parameters_str.append(f"{key} {value}")
+
+        parameters_str = ", ".join(parameters_str)
+        parts.append(f"PARAMETERS ({parameters_str})")
+
+        if vector_index_config.parallel is not None:
+            parts.append(f"PARALLEL {vector_index_config.parallel}")
+
+        return " ".join(parts)
+
     def define_constraint_cascades(self, constraint):
         text = ""
         if constraint.ondelete is not None:
@@ -1514,6 +1760,9 @@ class OracleDDLCompiler(compiler.DDLCompiler):
             text += "UNIQUE "
         if index.dialect_options["oracle"]["bitmap"]:
             text += "BITMAP "
+        vector_options = index.dialect_options["oracle"]["vector"]
+        if vector_options:
+            text += "VECTOR "
         text += "INDEX %s ON %s (%s)" % (
             self._prepared_index_name(index, include_schema=True),
             preparer.format_table(index.table, use_schema=True),
@@ -1531,6 +1780,11 @@ class OracleDDLCompiler(compiler.DDLCompiler):
                 text += " COMPRESS %d" % (
                     index.dialect_options["oracle"]["compress"]
                 )
+        if vector_options:
+            if vector_options is True:
+                vector_options = VectorIndexConfig()
+
+            text += " " + self._build_vector_index_config(vector_options)
         return text
 
     def post_create_table(self, table):
@@ -1682,9 +1936,18 @@ class OracleDialect(default.DefaultDialect):
                 "tablespace": None,
             },
         ),
-        (sa_schema.Index, {"bitmap": False, "compress": False}),
+        (
+            sa_schema.Index,
+            {
+                "bitmap": False,
+                "compress": False,
+                "vector": False,
+            },
+        ),
         (sa_schema.Sequence, {"order": None}),
         (sa_schema.Identity, {"order": None, "on_null": None}),
+        (sa_selectable.Select, {"fetch_approximate": False}),
+        (sa_selectable.CompoundSelect, {"fetch_approximate": False}),
     ]
 
     @util.deprecated_params(
diff --git a/lib/sqlalchemy/dialects/oracle/vector.py b/lib/sqlalchemy/dialects/oracle/vector.py
new file mode 100644 (file)
index 0000000..dae89d3
--- /dev/null
@@ -0,0 +1,266 @@
+# dialects/oracle/vector.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import array
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+import sqlalchemy.types as types
+from sqlalchemy.types import Float
+
+
+class VectorIndexType(Enum):
+    """Enum representing different types of VECTOR index structures.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    HNSW = "HNSW"
+    """
+    The HNSW (Hierarchical Navigable Small World) index type.
+    """
+    IVF = "IVF"
+    """
+    The IVF (Inverted File Index) index type
+    """
+
+
+class VectorDistanceType(Enum):
+    """Enum representing different types of vector distance metrics.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    EUCLIDEAN = "EUCLIDEAN"
+    """Euclidean distance (L2 norm).
+
+    Measures the straight-line distance between two vectors in space.
+    """
+    DOT = "DOT"
+    """Dot product similarity.
+
+    Measures the algebraic similarity between two vectors.
+    """
+    COSINE = "COSINE"
+    """Cosine similarity.
+
+    Measures the cosine of the angle between two vectors.
+    """
+    MANHATTAN = "MANHATTAN"
+    """Manhattan distance (L1 norm).
+
+    Calculates the sum of absolute differences across dimensions.
+    """
+
+
+class VectorStorageFormat(Enum):
+    """Enum representing the data format used to store vector components.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    INT8 = "INT8"
+    """
+    8-bit integer format.
+    """
+    BINARY = "BINARY"
+    """
+    Binary format.
+    """
+    FLOAT32 = "FLOAT32"
+    """
+    32-bit floating-point format.
+    """
+    FLOAT64 = "FLOAT64"
+    """
+    64-bit floating-point format.
+    """
+
+
+@dataclass
+class VectorIndexConfig:
+    """Define the configuration for Oracle VECTOR Index.
+
+    See :ref:`oracle_vector_datatype` for background.
+
+    .. versionadded:: 2.0.41
+
+    :param index_type: Enum value from :class:`.VectorIndexType`
+     Specifies the indexing method. For HNSW, this must be
+     :attr:`.VectorIndexType.HNSW`.
+
+    :param distance: Enum value from :class:`.VectorDistanceType`
+     specifies the metric for calculating distance between VECTORS.
+
+    :param accuracy: interger. Should be in the range 0 to 100
+     Specifies the accuracy of the nearest neighbor search during
+     query execution.
+
+    :param parallel: integer. Specifies degree of parallelism.
+
+    :param hnsw_neighbors: interger. Should be in the range 0 to
+     2048. Specifies the number of nearest neighbors considered
+     during the search. The attribute :attr:`.VectorIndexConfig.hnsw_neighbors`
+     is HNSW index specific.
+
+    :param hnsw_efconstruction: integer. Should be in the range 0
+     to 65535. Controls the trade-off between indexing speed and
+     recall quality during index construction. The attribute
+     :attr:`.VectorIndexConfig.hnsw_efconstruction` is HNSW index
+     specific.
+
+    :param ivf_neighbor_partitions: integer. Should be in the range
+     0 to 10,000,000. Specifies the number of partitions used to
+     divide the dataset. The attribute
+     :attr:`.VectorIndexConfig.ivf_neighbor_partitions` is IVF index
+     specific.
+
+    :param ivf_sample_per_partition: integer. Should be between 1
+     and ``num_vectors / neighbor partitions``. Specifies the
+     number of samples used per partition. The attribute
+     :attr:`.VectorIndexConfig.ivf_sample_per_partition` is IVF index
+     specific.
+
+    :param ivf_min_vectors_per_partition: integer. From 0 (no trimming)
+     to the total number of vectors (results in 1 partition). Specifies
+     the minimum number of vectors per partition. The attribute
+     :attr:`.VectorIndexConfig.ivf_min_vectors_per_partition`
+     is IVF index specific.
+
+    """
+
+    index_type: VectorIndexType = VectorIndexType.HNSW
+    distance: Optional[VectorDistanceType] = None
+    accuracy: Optional[int] = None
+    hnsw_neighbors: Optional[int] = None
+    hnsw_efconstruction: Optional[int] = None
+    ivf_neighbor_partitions: Optional[int] = None
+    ivf_sample_per_partition: Optional[int] = None
+    ivf_min_vectors_per_partition: Optional[int] = None
+    parallel: Optional[int] = None
+
+    def __post_init__(self):
+        self.index_type = VectorIndexType(self.index_type)
+        for field in [
+            "hnsw_neighbors",
+            "hnsw_efconstruction",
+            "ivf_neighbor_partitions",
+            "ivf_sample_per_partition",
+            "ivf_min_vectors_per_partition",
+            "parallel",
+            "accuracy",
+        ]:
+            value = getattr(self, field)
+            if value is not None and not isinstance(value, int):
+                raise TypeError(
+                    f"{field} must be an integer if"
+                    f"provided, got {type(value).__name__}"
+                )
+
+
+class VECTOR(types.TypeEngine):
+    """Oracle VECTOR datatype.
+
+    For complete background on using this type, see
+    :ref:`oracle_vector_datatype`.
+
+    .. versionadded:: 2.0.41
+
+    """
+
+    cache_ok = True
+    __visit_name__ = "VECTOR"
+
+    _typecode_map = {
+        VectorStorageFormat.INT8: "b",  # Signed int
+        VectorStorageFormat.BINARY: "B",  # Unsigned int
+        VectorStorageFormat.FLOAT32: "f",  # Float
+        VectorStorageFormat.FLOAT64: "d",  # Double
+    }
+
+    def __init__(self, dim=None, storage_format=None):
+        """Construct a VECTOR.
+
+        :param dim: integer. The dimension of the VECTOR datatype. This
+         should be an integer value.
+
+        :param storage_format: VectorStorageFormat. The VECTOR storage
+         type format. This may be Enum values form
+         :class:`.VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64.
+
+        """
+        if dim is not None and not isinstance(dim, int):
+            raise TypeError("dim must be an interger")
+        if storage_format is not None and not isinstance(
+            storage_format, VectorStorageFormat
+        ):
+            raise TypeError(
+                "storage_format must be an enum of type VectorStorageFormat"
+            )
+        self.dim = dim
+        self.storage_format = storage_format
+
+    def _cached_bind_processor(self, dialect):
+        """
+        Convert a list to a array.array before binding it to the database.
+        """
+
+        def process(value):
+            if value is None or isinstance(value, array.array):
+                return value
+
+            # Convert list to a array.array
+            elif isinstance(value, list):
+                typecode = self._array_typecode(self.storage_format)
+                value = array.array(typecode, value)
+                return value
+
+            else:
+                raise TypeError("VECTOR accepts list or array.array()")
+
+        return process
+
+    def _cached_result_processor(self, dialect, coltype):
+        """
+        Convert a array.array to list before binding it to the database.
+        """
+
+        def process(value):
+            if isinstance(value, array.array):
+                return list(value)
+
+        return process
+
+    def _array_typecode(self, typecode):
+        """
+        Map storage format to array typecode.
+        """
+        return self._typecode_map.get(typecode, "d")
+
+    class comparator_factory(types.TypeEngine.Comparator):
+        def l2_distance(self, other):
+            return self.op("<->", return_type=Float)(other)
+
+        def inner_product(self, other):
+            return self.op("<#>", return_type=Float)(other)
+
+        def cosine_distance(self, other):
+            return self.op("<=>", return_type=Float)(other)
index c945c355c794b2196d5d073ef630f8ddab669472..462d96b27acdd2dcaffefd1a01a122094f571888 100644 (file)
@@ -73,6 +73,7 @@ from .base import ColumnCollection
 from .base import ColumnSet
 from .base import CompileState
 from .base import DedupeColumnCollection
+from .base import DialectKWArgs
 from .base import Executable
 from .base import Generative
 from .base import HasCompileState
@@ -3890,7 +3891,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]):
         raise NotImplementedError
 
 
-class GenerativeSelect(SelectBase, Generative):
+class GenerativeSelect(DialectKWArgs, SelectBase, Generative):
     """Base class for SELECT statements where additional elements can be
     added.
 
@@ -4171,8 +4172,9 @@ class GenerativeSelect(SelectBase, Generative):
         count: _LimitOffsetType,
         with_ties: bool = False,
         percent: bool = False,
+        **dialect_kw: Any,
     ) -> Self:
-        """Return a new selectable with the given FETCH FIRST criterion
+        r"""Return a new selectable with the given FETCH FIRST criterion
         applied.
 
         This is a numeric value which usually renders as ``FETCH {FIRST | NEXT}
@@ -4202,6 +4204,11 @@ class GenerativeSelect(SelectBase, Generative):
         :param percent: When ``True``, ``count`` represents the percentage
          of the total number of selected rows to return. Defaults to ``False``
 
+        :param \**dialect_kw: Additional dialect-specific keyword arguments
+         may be accepted by dialects.
+
+         .. versionadded:: 2.0.41
+
         .. seealso::
 
            :meth:`_sql.GenerativeSelect.limit`
@@ -4209,7 +4216,7 @@ class GenerativeSelect(SelectBase, Generative):
            :meth:`_sql.GenerativeSelect.offset`
 
         """
-
+        self._validate_dialect_kwargs(dialect_kw)
         self._limit_clause = None
         if count is None:
             self._fetch_clause = self._fetch_clause_options = None
@@ -4455,6 +4462,7 @@ class CompoundSelect(
         ]
         + SupportsCloneAnnotations._clone_annotations_traverse_internals
         + HasCTE._has_ctes_traverse_internals
+        + DialectKWArgs._dialect_kwargs_traverse_internals
     )
 
     selects: List[SelectBase]
@@ -5342,6 +5350,7 @@ class Select(
         + HasHints._has_hints_traverse_internals
         + SupportsCloneAnnotations._clone_annotations_traverse_internals
         + Executable._executable_traverse_internals
+        + DialectKWArgs._dialect_kwargs_traverse_internals
     )
 
     _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [
@@ -5363,7 +5372,9 @@ class Select(
         stmt.__dict__.update(kw)
         return stmt
 
-    def __init__(self, *entities: _ColumnsClauseArgument[Any]):
+    def __init__(
+        self, *entities: _ColumnsClauseArgument[Any], **dialect_kw: Any
+    ):
         r"""Construct a new :class:`_expression.Select`.
 
         The public constructor for :class:`_expression.Select` is the
@@ -5376,7 +5387,6 @@ class Select(
             )
             for ent in entities
         ]
-
         GenerativeSelect.__init__(self)
 
     def _apply_syntax_extension_to_self(
index c7f4a0c492b50623628b12111b1dc963281be525..625547efb1b34ba01bd1b755f874aa1d3b8cb4fd 100644 (file)
@@ -312,6 +312,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             checkparams={"param_1": 20, "param_2": 10},
         )
 
+    @testing.only_on("oracle>=23.4")
+    def test_fetch_type(self):
+        t = table("sometable", column("col1"), column("col2"))
+        s = select(t).fetch(2, oracle_fetch_approximate=True)
+        self.assert_compile(
+            s,
+            "SELECT sometable.col1, sometable.col2 FROM sometable "
+            "FETCH APPROX FIRST __[POSTCOMPILE_param_1] ROWS ONLY",
+            checkparams={"param_1": 2},
+        )
+
     def test_limit_two(self):
         t = table("sometable", column("col1"), column("col2"))
         s = select(t).limit(10).offset(20).subquery()
index f93957526940a75e83edfa8299b47eaac9cfcfa2..93f89cf5d56c8a606270751c4dab816806148dc3 100644 (file)
@@ -21,6 +21,11 @@ from sqlalchemy import text
 from sqlalchemy import Unicode
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects import oracle
+from sqlalchemy.dialects.oracle import VECTOR
+from sqlalchemy.dialects.oracle import VectorDistanceType
+from sqlalchemy.dialects.oracle import VectorIndexConfig
+from sqlalchemy.dialects.oracle import VectorIndexType
+from sqlalchemy.dialects.oracle import VectorStorageFormat
 from sqlalchemy.dialects.oracle.base import BINARY_DOUBLE
 from sqlalchemy.dialects.oracle.base import BINARY_FLOAT
 from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION
@@ -698,6 +703,25 @@ class TableReflectionTest(fixtures.TestBase):
         tbl = Table("test_tablespace", m2, autoload_with=connection)
         assert tbl.dialect_options["oracle"]["tablespace"] == "TEMP"
 
+    @testing.only_on("oracle>=23.4")
+    def test_reflection_w_vector_column(self, connection, metadata):
+        tb1 = Table(
+            "test_vector",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(30)),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+            ),
+        )
+        metadata.create_all(connection)
+
+        m2 = MetaData()
+
+        tb1 = Table("test_vector", m2, autoload_with=connection)
+        assert tb1.columns.keys() == ["id", "name", "embedding"]
+
 
 class ViewReflectionTest(fixtures.TestBase):
     __only_on__ = "oracle"
@@ -1180,6 +1204,42 @@ class RoundTripIndexTest(fixtures.TestBase):
         eq_(len(reflectedtable.constraints), 1)
         eq_(len(reflectedtable.indexes), 5)
 
+    @testing.only_on("oracle>=23.4")
+    def test_vector_index(self, metadata, connection):
+        tb1 = Table(
+            "test_vector",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(30)),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+            ),
+        )
+        tb1.create(connection)
+
+        ivf_index = Index(
+            "ivf_vector_index",
+            tb1.c.embedding,
+            oracle_vector=VectorIndexConfig(
+                index_type=VectorIndexType.IVF,
+                distance=VectorDistanceType.DOT,
+                accuracy=90,
+                ivf_neighbor_partitions=5,
+            ),
+        )
+        ivf_index.create(connection)
+
+        expected = [
+            {
+                "name": "ivf_vector_index",
+                "column_names": ["embedding"],
+                "dialect_options": {},
+                "unique": False,
+            },
+        ]
+        eq_(inspect(connection).get_indexes("test_vector"), expected)
+
 
 class DBLinkReflectionTest(fixtures.TestBase):
     __requires__ = ("oracle_test_dblink",)
index b5ce61222e8f5ab6c5d300032145447b09a1ccd4..dc060f27e03ca77da4c8022727ca4d3b611b58e2 100644 (file)
@@ -1,3 +1,4 @@
+import array
 import datetime
 import decimal
 import os
@@ -15,6 +16,7 @@ from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import FLOAT
 from sqlalchemy import Float
+from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import LargeBinary
 from sqlalchemy import literal
@@ -37,6 +39,11 @@ from sqlalchemy import VARCHAR
 from sqlalchemy.dialects.oracle import base as oracle
 from sqlalchemy.dialects.oracle import cx_oracle
 from sqlalchemy.dialects.oracle import oracledb
+from sqlalchemy.dialects.oracle import VECTOR
+from sqlalchemy.dialects.oracle import VectorDistanceType
+from sqlalchemy.dialects.oracle import VectorIndexConfig
+from sqlalchemy.dialects.oracle import VectorIndexType
+from sqlalchemy.dialects.oracle import VectorStorageFormat
 from sqlalchemy.sql import column
 from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -951,6 +958,194 @@ class TypesTest(fixtures.TestBase):
         finally:
             exec_sql(connection, "DROP TABLE Z_TEST")
 
+    @testing.only_on("oracle>=23.4")
+    def test_vector_dim(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)
+            ),
+        )
+
+        t1.create(connection)
+        eq_(t1.c.c1.type.dim, 3)
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_insert(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("c1", VECTOR(storage_format=VectorStorageFormat.INT8)),
+        )
+
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
+            dict(id=1, c1=[6, 7, 8, 5]),
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7, 8, 5]),
+        )
+        connection.execute(t1.delete().where(t1.c.id == 1))
+        connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7]),
+        )
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_insert_array(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("c1", VECTOR),
+        )
+
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
+            dict(id=1, c1=array.array("b", [6, 7, 8, 5])),
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7, 8, 5]),
+        )
+
+        connection.execute(t1.delete().where(t1.c.id == 1))
+
+        connection.execute(
+            t1.insert(), dict(id=1, c1=array.array("b", [6, 7]))
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7]),
+        )
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_multiformat_insert(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("c1", VECTOR),
+        )
+
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
+            dict(id=1, c1=[6.12, 7.54, 8.33]),
+        )
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6.12, 7.54, 8.33]),
+        )
+        connection.execute(t1.delete().where(t1.c.id == 1))
+        connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6, 7]),
+        )
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_format(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)
+            ),
+        )
+
+        t1.create(connection)
+        eq_(t1.c.c1.type.storage_format, VectorStorageFormat.FLOAT32)
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_hnsw_index(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+            ),
+        )
+
+        t1.create(connection)
+
+        hnsw_index = Index(
+            "hnsw_vector_index", t1.c.embedding, oracle_vector=True
+        )
+        hnsw_index.create(connection)
+
+        connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6.0, 7.0, 8.0]),
+        )
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_ivf_index(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+            ),
+        )
+
+        t1.create(connection)
+        ivf_index = Index(
+            "ivf_vector_index",
+            t1.c.embedding,
+            oracle_vector=VectorIndexConfig(
+                index_type=VectorIndexType.IVF,
+                distance=VectorDistanceType.DOT,
+                accuracy=90,
+                ivf_neighbor_partitions=5,
+            ),
+        )
+        ivf_index.create(connection)
+
+        connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
+        eq_(
+            connection.execute(t1.select()).first(),
+            (1, [6.0, 7.0, 8.0]),
+        )
+
+    @testing.only_on("oracle>=23.4")
+    def test_vector_l2_distance(self, metadata, connection):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer),
+            Column(
+                "embedding",
+                VECTOR(dim=3, storage_format=VectorStorageFormat.INT8),
+            ),
+        )
+
+        t1.create(connection)
+
+        connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10]))
+        connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3]))
+        connection.execute(
+            t1.insert(),
+            dict(id=3, embedding=[15, 16, 17]),
+        )
+
+        query_vector = [2, 3, 4]
+        res = connection.execute(
+            t1.select().order_by((t1.c.embedding.l2_distance(query_vector)))
+        ).first()
+        eq_(res.embedding, [1, 2, 3])
+
 
 class LOBFetchTest(fixtures.TablesTest):
     __only_on__ = "oracle"
index 733dcd0aebd6401d2c6b994b439cae6ad6cb36ac..9c9bde1dacfbd50a022add4f74736bab5052ce26 100644 (file)
@@ -43,6 +43,7 @@ from sqlalchemy.sql import True_
 from sqlalchemy.sql import type_coerce
 from sqlalchemy.sql import visitors
 from sqlalchemy.sql.annotation import Annotated
+from sqlalchemy.sql.base import DialectKWArgs
 from sqlalchemy.sql.base import HasCacheKey
 from sqlalchemy.sql.base import SingletonConstant
 from sqlalchemy.sql.base import SyntaxExtension
@@ -549,6 +550,7 @@ class CoreFixtures:
             select(table_a.c.a).fetch(2, percent=True),
             select(table_a.c.a).fetch(2, with_ties=True),
             select(table_a.c.a).fetch(2, with_ties=True, percent=True),
+            select(table_a.c.a).fetch(2, oracle_fetch_approximate=True),
             select(table_a.c.a).fetch(2).offset(3),
             select(table_a.c.a).fetch(2).offset(5),
             select(table_a.c.a).limit(2).offset(5),
@@ -1682,6 +1684,7 @@ class HasCacheKeySubclass(fixtures.TestBase):
                 NoInit,
                 SingletonConstant,
                 SyntaxExtension,
+                DialectKWArgs,
             ]
         )
     )