From a0b3fa8fc993d2f70a089a649edb1cf18b12a84a Mon Sep 17 00:00:00 2001 From: suraj Date: Thu, 10 Apr 2025 19:54:07 +0530 Subject: [PATCH] Fixes: sqlalchemy#12317 Added vector datatype support in Oracle dialect included vector_Storage Fixes: sqlalchemy#12317 Added vector datatype support in Oracle dialect --- doc/build/dialects/oracle.rst | 4 + lib/sqlalchemy/dialects/oracle/__init__.py | 8 + lib/sqlalchemy/dialects/oracle/base.py | 221 +++++++++++++---- lib/sqlalchemy/dialects/oracle/oracledb.py | 222 ----------------- lib/sqlalchemy/dialects/oracle/types.py | 82 ------- lib/sqlalchemy/dialects/oracle/vector.py | 243 +++++++++++++++++++ test/dialect/oracle/test_types.py | 267 ++++++++++----------- 7 files changed, 568 insertions(+), 479 deletions(-) create mode 100644 lib/sqlalchemy/dialects/oracle/vector.py diff --git a/doc/build/dialects/oracle.rst b/doc/build/dialects/oracle.rst index b3d44858ce..c054a92211 100644 --- a/doc/build/dialects/oracle.rst +++ b/doc/build/dialects/oracle.rst @@ -31,6 +31,7 @@ originate from :mod:`sqlalchemy.types` or from the local dialect:: TIMESTAMP, VARCHAR, VARCHAR2, + VECTOR, ) .. versionadded:: 1.2.19 Added :class:`_types.NCHAR` to the list of datatypes @@ -80,6 +81,9 @@ construction arguments, are as follows: .. autoclass:: TIMESTAMP :members: __init__ +.. autoclass:: VECTOR + :members: __init__ + .. 
_oracledb: python-oracledb diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index c05d8bdf87..2265de033c 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -33,6 +33,10 @@ from .base import TIMESTAMP from .base import VARCHAR from .base import VARCHAR2 from .base import VECTOR +from .base import VectorIndexConfig +from .base import VectorIndexType +from .vector import VectorDistanceType +from .vector import VectorStorageFormat # Alias oracledb also as oracledb_async oracledb_async = type( @@ -66,4 +70,8 @@ __all__ = ( "ROWID", "REAL", "VECTOR", + "VectorDistanceType", + "VectorIndexType", + "VectorIndexConfig", + "VectorStorageFormat", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 131a341bb4..7b57a90e93 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -744,11 +744,140 @@ The ``oracle_compress`` parameter accepts either an integer specifying the number of prefix columns to compress, or ``True`` to use the default (all columns for non-unique indexes, all but the last column for unique indexes). +VECTOR Datatype +--------------- + +Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine +learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers, +8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers. +For more information on the VECTOR datatype please visit this `link +`_. + +CREATE TABLE +~~~~~~~~~~~~ + +With the :class:`.VECTOR` datatype, you can specify the dimension for the data and the storage +format. Valid values for storage format are enum values from :class:`.VectorStorageFormat` +(:attr:`.VectorStorageFormat.INT8`, :attr:`.VectorStorageFormat.BINARY`, :attr:`.VectorStorageFormat.FLOAT32`, +:attr:`.VectorStorageFormat.FLOAT64`). 
+To create a table that includes a :class:`.VECTOR` column:: + + from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat + + t1 = Table("t1", metadata, + Column('id', Integer, primary_key=True), + Column("embedding", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)), + Column(...), ... + ) + +Vectors can also be defined with an arbitrary number of dimensions and formats. This allows +you to specify vectors of different dimensions with the various storage formats mentioned above. + +For Example + +* In this case, the storage format is flexible, allowing any vector type data to be inserted, + such as INT8 or BINARY etc. + + vector_col:Mapped[array.array] = mapped_column(VECTOR(dim=3)) + +* The dimension is flexible in this case, meaning that any dimension vector can be used. + + vector_col:Mapped[array.array] = mapped_column(VECTOR(storage_format=VectorStorageFormat.INT8)) + +* Both the dimensions and the storage format are flexible. + + vector_col:Mapped[array.array] = mapped_column(VECTOR) + +INSERT VECTOR DATA +~~~~~~~~~~~~~~~~~~ + +VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type +FLOAT (32-bit), DOUBLE (64-bit), or INT (8-bit signed integer) are used as bind values when +inserting VECTOR columns:: + + from sqlalchemy import insert, select + import array + + vector_data_8 = [1, 2, 3] + statement = insert(t1) + with engine.connect() as conn: + conn.execute(statement,[ + {"id":1,"embedding":vector_data_8}, + ]) + +VECTOR INDEXES +~~~~~~~~~~~~~~ + +There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW +index. + +To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use +the default values provided by Oracle.
HNSW is the default indexing method:: + + Index( + 'vector_index', + t1.c.embedding, + oracle_vector = True, + ) + +If you wish to use custom parameters, you can specify all the parameters using the VectorIndexConfig +Dataclass in the `oracle_vector` option. To learn more about the parameters that can be passed please +visit this `link. `_ + + Index( + 'hnsw_vector_index', + t1.c.embedding, + oracle_vector = VectorIndexConfig( + index_type = VectorIndexType.HNSW, + distance = VectorDistanceType.COSINE, + accuracy = 90, + hnsw_neighbors = 5, + hnsw_efconstruction = 20, + parallel = 10 + ) + ) + + Index( + 'ivf_vector_index', + t1.c.embedding, + oracle_vector = VectorIndexConfig( + index_type = VectorIndexType.IVF, + distance = VectorDistanceType.DOT, + accuracy = 90, + ivf_neighbor_partitions = 5, + ) + ) + +Similarity Searching +~~~~~~~~~~~~~~~~~~~~ + +You can use the following shorthand VECTOR distance functions: + +* ``l2_distance`` +* ``cosine_distance`` +* ``inner_product`` + +Example Usage:: + + from sqlalchemy.orm import Session + from sqlalchemy.sql import func + import array + + session = Session(bind=engine) + query_vector = [2,3,4] + result_vector = session.scalars(select(t1).order_by(t1.c.embedding.l2_distance(query_vector)).limit(3)) + + for user in result_vector: + print(user.id,user.embedding) + +.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database. + """ # noqa from __future__ import annotations from collections import defaultdict +from dataclasses import fields from functools import lru_cache from functools import wraps import re @@ -771,7 +900,9 @@ from .types import RAW from .types import ROWID # noqa from .types import TIMESTAMP from .types import VARCHAR2 # noqa -from .types import VECTOR +from .vector import VECTOR +from .vector import VectorIndexConfig +from .vector import VectorIndexType from ... import Computed from ... import exc from ...
import schema as sa_schema @@ -1011,13 +1142,13 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_VECTOR(self, type_, **kw): if type_.dim is None and type_.storage_format is None: - return f"VECTOR" + return "VECTOR(*,*)" elif type_.storage_format is None: return f"VECTOR({type_.dim},*)" elif type_.dim is None: - return f"VECTOR(*, {type_.storage_format})" + return f"VECTOR(*,{type_.storage_format.value})" else: - return f"VECTOR({type_.dim},{type_.storage_format})" + return f"VECTOR({type_.dim},{type_.storage_format.value})" class OracleCompiler(compiler.SQLCompiler): @@ -1505,6 +1636,48 @@ class OracleCompiler(compiler.SQLCompiler): class OracleDDLCompiler(compiler.DDLCompiler): + + def _build_vector_index_config( + self, vector_index_config: VectorIndexConfig + ) -> str: + parts = [] + sql_param_name = { + "hnsw_neighbors": "neighbors", + "hnsw_efconstruction": "efconstruction", + "ivf_neighbor_partitions": "neighbor partitions", + "ivf_sample_per_partition": "sample_per_partition", + "ivf_min_vectors_per_partition": "min_vectors_per_partition", + } + if vector_index_config.index_type == VectorIndexType.HNSW: + parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH") + elif vector_index_config.index_type == VectorIndexType.IVF: + parts.append("ORGANIZATION NEIGHBOR PARTITIONS") + if vector_index_config.distance is not None: + parts.append(f"DISTANCE {vector_index_config.distance.value}") + + if vector_index_config.accuracy is not None: + parts.append( + f"WITH TARGET ACCURACY {vector_index_config.accuracy}" + ) + + parameters_str = [f"type {vector_index_config.index_type.name}"] + prefix = vector_index_config.index_type.name.lower() + "_" + + for field in fields(vector_index_config): + if field.name.startswith(prefix): + key = sql_param_name.get(field.name) + value = getattr(vector_index_config, field.name) + if value is not None: + parameters_str.append(f"{key} {value}") + + parameters_str = ", ".join(parameters_str) + parts.append(f"PARAMETERS 
({parameters_str})") + + if vector_index_config.parallel is not None: + parts.append(f"PARALLEL {vector_index_config.parallel}") + + return " ".join(parts) + def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -1559,43 +1732,9 @@ class OracleDDLCompiler(compiler.DDLCompiler): ) if vector_options: if vector_options is True: - vector_options = {} - parts = [] - parameters = vector_options.get("parameters", {}) - using = parameters.get("type", "HNSW").upper() - if using == "HNSW": - parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH") - elif using == "IVF": - parts.append("ORGANIZATION NEIGHBOR PARTITIONS") - vector_distance = vector_options.get("distance") - if vector_distance is not None: - vector_distance = vector_distance.upper() - if vector_distance not in ( - "EUCLIDEAN", - "DOT", - "COSINE", - "MANHATTAN", - ): - raise ValueError("Unknown vector_distance value") - parts.append(f"DISTANCE {vector_distance}") - target_accuracy = vector_options.get("accuracy") - if target_accuracy is not None: - if target_accuracy < 0 or target_accuracy > 100: - raise ValueError( - "Accuracy value should be an integer between 0 and 100" - ) - parts.append(f"WITH TARGET ACCURACY {target_accuracy}") - if parameters: - parameters_str = ", ".join( - f"{k} {v}" for k, v in parameters.items() - ) - parts.append(f"PARAMETERS ({parameters_str})") - parallel = vector_options.get("parallel") - if parallel is not None: - if not isinstance(parallel, int): - raise ValueError("Parallel value must be an integer") - parts.append(f"PARALLEL {parallel}") - text += " " + " ".join(parts) + vector_options = VectorIndexConfig() + + text += " " + self._build_vector_index_config(vector_options) return text def post_create_table(self, table): diff --git a/lib/sqlalchemy/dialects/oracle/oracledb.py b/lib/sqlalchemy/dialects/oracle/oracledb.py index 201902f997..8105608837 100644 --- a/lib/sqlalchemy/dialects/oracle/oracledb.py +++ 
b/lib/sqlalchemy/dialects/oracle/oracledb.py @@ -591,228 +591,6 @@ SQLAlchemy type (or a subclass of such). .. versionadded:: 2.0.0 added support for the python-oracledb driver. -VECTOR Datatype ---------------- - -Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine -learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers, -8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers. -For more information on the VECTOR datatype please visit this `link. -`_ - -CREATE TABLE -^^^^^^^^^^^^ - -With the VECTOR datatype, you can specify the dimension for the data and the storage -format. To create a table that includes a VECTOR column:: - - from sqlalchemy.dialects.oracle import VECTOR - - t = Table("t1", metadata, - Column('id', Integer, primary_key=True), - Column("embedding", VECTOR(3, 'float32'), - Column(...), ... - ) - -Vectors can also be defined with an arbitrary number of dimensions and formats. This allows -you to specify vectors of different dimensions with the various storage formats mentioned above. - -For Example - -* In this case, the storage format is flexible, allowing any vector type data to be inserted, - such as int8 or binary etc. - - vector_col:Mapped[array.array] = mapped_column(VECTOR(3)) - -* The dimension is flexible in this case, meaning that any dimension vector can be used. - - vector_col:Mapped[array.array] = mapped_column(VECTOR('int8')) - -* Both the dimensions and the storage format are flexible. - - vector_col:Mapped[array.array] = mapped_column(VECTOR) - -INSERT VECTOR DATA -^^^^^^^^^^^^^^^^^^ - -VECTOR data can be inserted using Python list or Python array.array() objects. 
Python arrays of type -float (32-bit), double (64-bit), or int8_t (8-bit signed integer) are used as bind values when -inserting VECTOR columns:: - - from sqlalchemy import insert, select - import array - - vector_data_8 = [1, 2, 3] - statement = insert(t1) - with engine.connect() as conn: - conn.execute(statement,[ - {"id":1,"embedding":vector_data_8}, - ]) - -VECTOR INDEXES -^^^^^^^^^^^^^^ - -There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW -index. - -To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use -the default values provided by Oracle. HNSW is the default indexing method:: - - Index( - 'vector_index', - t1.c.embedding, - oracle_vector = True, - ) - -If you wish to use custom parameters, you can specify all the parameters as a dictionary -in the `oracle_vector` option. To learn more about the parameters that can be passed please -visit this `link. `_ - -Configuring Oracle VECTOR Indexes -================================= - -When using Oracle VECTOR indexes, the configuration parameters are divided into two levels: -**top-level keys** and **nested keys**. This structure applies to both HNSW and IVF VECTOR indexes. - -Top-Level Keys -============== - -These keys are specified directly under the ``oracle_vector`` dictionary. - -* ``accuracy``: - - Specifies the accuracy of the nearest neighbor search during query execution. - - **Valid Range**: Greater than 0 and less than or equal to 100. - - **Example**: ``'accuracy': 95`` - -* ``distance``: - - Specifies the metric for calculating distance between VECTORS. - - **Valid Values**: ``"EUCLIDEAN"``, ``"COSINE"``, ``"DOT"``, ``"MANHATTAN"``. - - **Example**: ``'distance': "COSINE"`` - -* ``parameters``: - - A nested dictionary where method-specific options are defined (e.g., HNSW or IVF-specific settings). 
- - **Example**: ``'parameters': {...}`` - -Nested keys in parameters -========================= - -These keys are specific to the indexing method and are included under the ``parameters`` dictionary. - -HNSW Parameters -=============== - -* ``type``: - - Specifies the indexing method. For HNSW, this must be ``"HNSW"``. - - **Placement**: Nested under ``parameters``. - - **Example**: ``'type': 'HNSW'`` - -* ``neighbors``: - - The number of nearest neighbors considered during the search. - - **Valid Range**: Greater than 0 and less than or equal to 2048. - - **Placement**: Nested under ``parameters``. - - **Example**: ``'neighbors': 20`` - -* ``efconstruction``: - - Controls the trade-off between indexing speed and recall quality during index construction. - - **Valid Range**: Greater than 0 and less than or equal to 65535. - - **Placement**: Nested under ``parameters``. - - **Example**: ``'efconstruction': 300`` - -IVF Parameters -============== - -* ``type``: - - Specifies the indexing method. For IVF, this must be ``"IVF"``. - - **Placement**: Nested under ``parameters``. - - **Example**: ``'type': 'IVF'`` - -* ``neighbor partitions``: - - The number of partitions used to divide the dataset. - - **Valid Range**: Greater than 0 and less than or equal to 10,000,000. - - **Placement**: Nested under ``parameters``. - - **Example**: ``'neighbor partitions': 10`` - -* ``sample_per_partition``: - - The number of samples used per partition. - - **Valid Range**: Between 1 and ``num_vectors / neighbor_partitions``. - - **Placement**: Nested under ``parameters``. - - **Example**: ``'sample_per_partition': 5`` - -* ``min_vectors_per_partition``: - - The minimum number of vectors per partition. - - **Valid Range**: From 0 (no trimming) to the total number of vectors (results in 1 partition). - - **Placement**: Nested under ``parameters``. 
- - **Example**: ``'min_vectors_per_partition': 100`` - -Example Configurations -====================== - -For custom configurations, the parameters can be specified as shown in the following examples:: - - Index( - 'hnsw_vector_index', - t1.c.embedding, - oracle_vector = { - 'accuracy':95, - 'distance':"COSINE", - 'parameters':{ - 'type':'HNSW', - 'neighbors':20', - efconstruction':300 - } - } - ) - - Index( - 'ivf_vector_index', - t1.c.embedding, - oracle_vector = { - 'accuracy':90, - 'distance':"DOT", - 'parameters':{ - 'type':'IVF',' - neighbor partitions':10 - } - } - ) - -Similarity Searching -^^^^^^^^^^^^^^^^^^^^ - -You can use the following shorthand VECTOR distance functions: - -* ``l2_distance`` -* ``cosine_distance`` -* ``inner_product`` - -Example Usage:: - - from sqlalchemy.orm import Session - from sqlalchemy.sql import func - import array - - session = Session(bind=engine) - query_vector = [2,3,4] - result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3)) - - for user in vector: - print(user.id,user.embedding) - -Exact and Approximate Searching -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Similarity searches tend to get data from one or more clusters depending on the value of the query VECTOR and the fetch -size. Approximate searches using VECTOR indexes can limit the searches to specific clusters, whereas exact searches visit -VECTORS across all clusters. -You can use the fetch_type clause to set the searching to be either EXACT, APPROX or APPROXIMATE:: - - result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("EXACT") - result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("APPROX") - - -.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database. -:class:`_oracle.VECTOR` datatype. 
- """ # noqa from __future__ import annotations diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index f128a42352..06aeaace2f 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -11,14 +11,11 @@ import datetime as dt from typing import Optional from typing import Type from typing import TYPE_CHECKING -import sqlalchemy.types as types -from sqlalchemy.types import UserDefinedType, Float from ... import exc from ...sql import sqltypes from ...types import NVARCHAR from ...types import VARCHAR -import array if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -317,82 +314,3 @@ class ROWID(sqltypes.TypeEngine): class _OracleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMBER - - -class VECTOR(types.TypeEngine): - """Oracle VECTOR datatype.""" - - cache_ok = True - __visit_name__ = "VECTOR" - - def __init__(self, dim=None, storage_format=None, *args): - """ - :param dim: The dimension of the VECTOR datatype. This should be an - integer value. - :param storage_format: The VECTOR storage type format. This - may be int8, binary, float32, or float64. - """ - if dim is not None and isinstance(dim, int): - self.dim = dim - self.storage_format = storage_format - - elif dim is not None and isinstance(dim, str): - self.dim = storage_format - self.storage_format = dim - - else: - self.dim = storage_format - self.storage_format = dim - - def _cached_bind_processor(self, dialect): - """ - Convert a list to a array.array before binding it to the database. 
- """ - - def process(value): - if value is None or isinstance(value, array.array): - return value - - # Convert list to a array.array - elif isinstance(value, list): - format = self._array_typecode(self.storage_format) - value = array.array(format, value) - return value - - else: - raise TypeError("VECTOR accepts list or array.array()") - - return process - - def _cached_result_processor(self, dialect, coltype): - """ - Convert a array.array to list before binding it to the database. - """ - - def process(value): - if isinstance(value, array.array): - return list(value) - - return process - - def _array_typecode(self, format): - """ - Map storage format to array typecode. - """ - typecode_map = { - "int8": "b", # Signed int - "binary": "B", # Unsigned int - "float32": "f", # Float - "float64": "d", # Double - } - return typecode_map.get(format, "d") - - class comparator_factory(types.TypeEngine.Comparator): - def l2_distance(self, other): - return self.op("<->", return_type=Float)(other) - - def inner_product(self, other): - return self.op("<#>", return_type=Float)(other) - - def cosine_distance(self, other): - return self.op("<=>", return_type=Float)(other) diff --git a/lib/sqlalchemy/dialects/oracle/vector.py b/lib/sqlalchemy/dialects/oracle/vector.py new file mode 100644 index 0000000000..84b64cd44f --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/vector.py @@ -0,0 +1,243 @@ +# dialects/oracle/vector.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import array +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +import sqlalchemy.types as types +from sqlalchemy.types import Float + + +class VectorIndexType(Enum): + """Enum representing different types of VECTOR index structures. 
+ + Each index type has different characteristics and configuration + parameters optimized for specific use cases in vector search. + + """ + + HNSW = "HNSW" + """ + The HNSW (Hierarchical Navigable Small World) index type. + """ + IVF = "IVF" + """ + The IVF (Inverted File Index) index type. + """ + + +class VectorDistanceType(Enum): + """Enum representing different types of vector distance metrics.""" + + EUCLIDEAN = "EUCLIDEAN" + """Euclidean distance (L2 norm). + + Measures the straight-line distance between two vectors in space. + """ + DOT = "DOT" + """Dot product similarity. + + Measures the algebraic similarity between two vectors. + """ + COSINE = "COSINE" + """Cosine similarity. + + Measures the cosine of the angle between two vectors. + """ + MANHATTAN = "MANHATTAN" + """Manhattan distance (L1 norm). + + Calculates the sum of absolute differences across dimensions. + """ + + +class VectorStorageFormat(Enum): + """Enum representing the data format used to store vector components. + + Choosing the right format balances precision, memory usage, + and performance for vector indexing and search. + """ + + INT8 = "INT8" + """ + 8-bit integer format. + """ + BINARY = "BINARY" + """ + Binary format. + """ + FLOAT32 = "FLOAT32" + """ + 32-bit floating-point format. + """ + FLOAT64 = "FLOAT64" + """ + 64-bit floating-point format. + """ + + +@dataclass +class VectorIndexConfig: + """Define the configuration for Oracle VECTOR Index. + + :param index_type: Enum values from :class:`.VectorIndexType`. + Specifies the indexing method. For HNSW, this must be + :attr:`.VectorIndexType.HNSW`. + + :param distance: Enum values from :class:`.VectorDistanceType`. + Specifies the metric for calculating distance between VECTORS. + + :param accuracy: integer. Should be in the range 0 to 100. + Specifies the accuracy of the nearest neighbor search during + query execution. + + :param parallel: integer. Specifies degree of parallelism. + + :param hnsw_neighbors: integer.
Should be in the range 0 to + 2048. Specifies the number of nearest neighbors considered + during the search. The attribute :attr:`VectorIndexConfig.hnsw_neighbors` + is HNSW index specific. + + :param hnsw_efconstruction: integer. Should be in the range 0 + to 65535. Controls the trade-off between indexing speed and + recall quality during index construction. The attribute + :attr:`VectorIndexConfig.hnsw_efconstruction` is HNSW index + specific. + + :param ivf_neighbor_partitions: integer. Should be in the range + 0 to 10,000,000. Specifies the number of partitions used to + divide the dataset. The attribute + :attr:`VectorIndexConfig.ivf_neighbor_partitions` is IVF index + specific. + + :param ivf_sample_per_partition: integer. Should be between 1 + and ``num_vectors / neighbor partitions``. Specifies the + number of samples used per partition. The attribute + :attr:`VectorIndexConfig.ivf_sample_per_partition` is IVF index + specific. + + :param ivf_min_vectors_per_partition: integer. From 0 (no trimming) + to the total number of vectors (results in 1 partition). Specifies + the minimum number of vectors per partition.
+ + """ + + index_type: VectorIndexType = VectorIndexType.HNSW + distance: Optional[VectorDistanceType] = None + accuracy: Optional[int] = None + hnsw_neighbors: Optional[int] = None + hnsw_efconstruction: Optional[int] = None + ivf_neighbor_partitions: Optional[int] = None + ivf_sample_per_partition: Optional[int] = None + ivf_min_vectors_per_partition: Optional[int] = None + parallel: Optional[int] = None + + def __post_init__(self): + self.index_type = VectorIndexType(self.index_type) + for field in [ + "hnsw_neighbors", + "hnsw_efconstruction", + "ivf_neighbor_partitions", + "ivf_sample_per_partition", + "ivf_min_vectors_per_partition", + "parallel", + "accuracy", + ]: + value = getattr(self, field) + if value is not None and not isinstance(value, int): + raise TypeError( + f"{field} must be an integer if" + f"provided, got {type(value).__name__}" + ) + + +class VECTOR(types.TypeEngine): + """Oracle VECTOR datatype.""" + + cache_ok = True + __visit_name__ = "VECTOR" + + def __init__(self, dim=None, storage_format=None): + """Construct a VECTOR. + + :param dim: integer. The dimension of the VECTOR datatype. This + should be an integer value. + + :param storage_format: VectorStorageFormat. The VECTOR storage + type format. This may be Enum values form + `VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64. + + """ + if dim is not None and not isinstance(dim, int): + raise TypeError("dim must be an interger") + if storage_format is not None and not isinstance( + storage_format, VectorStorageFormat + ): + raise TypeError( + "storage_format must be an enum of type VectorStorageFormat" + ) + self.dim = dim + self.storage_format = storage_format + + def _cached_bind_processor(self, dialect): + """ + Convert a list to a array.array before binding it to the database. 
+ """ + + def process(value): + if value is None or isinstance(value, array.array): + return value + + # Convert list to a array.array + elif isinstance(value, list): + typecode = self._array_typecode(self.storage_format) + value = array.array(typecode, value) + return value + + else: + raise TypeError("VECTOR accepts list or array.array()") + + return process + + def _cached_result_processor(self, dialect, coltype): + """ + Convert a array.array to list before binding it to the database. + """ + + def process(value): + if isinstance(value, array.array): + return list(value) + + return process + + def _array_typecode(self, typecode): + """ + Map storage format to array typecode. + """ + typecode_map = { + VectorStorageFormat.INT8: "b", # Signed int + VectorStorageFormat.BINARY: "B", # Unsigned int + VectorStorageFormat.FLOAT32: "f", # Float + VectorStorageFormat.FLOAT64: "d", # Double + } + return typecode_map.get(typecode, "d") + + class comparator_factory(types.TypeEngine.Comparator): + def l2_distance(self, other): + return self.op("<->", return_type=Float)(other) + + def inner_product(self, other): + return self.op("<#>", return_type=Float)(other) + + def cosine_distance(self, other): + return self.op("<=>", return_type=Float)(other) diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index cf22f20f97..e03a8e3a6c 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1,8 +1,8 @@ +import array import datetime import decimal import os import random -import array from sqlalchemy import bindparam from sqlalchemy import cast @@ -16,6 +16,7 @@ from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import FLOAT from sqlalchemy import Float +from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import LargeBinary from sqlalchemy import literal @@ -35,12 +36,14 @@ from sqlalchemy import types as sqltypes from sqlalchemy import Unicode from sqlalchemy import 
UnicodeText from sqlalchemy import VARCHAR -from sqlalchemy import Index -from sqlalchemy import delete from sqlalchemy.dialects.oracle import base as oracle from sqlalchemy.dialects.oracle import cx_oracle from sqlalchemy.dialects.oracle import oracledb from sqlalchemy.dialects.oracle import VECTOR +from sqlalchemy.dialects.oracle import VectorDistanceType +from sqlalchemy.dialects.oracle import VectorIndexConfig +from sqlalchemy.dialects.oracle import VectorIndexType +from sqlalchemy.dialects.oracle import VectorStorageFormat from sqlalchemy.sql import column from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import AssertsCompiledSQL @@ -955,44 +958,45 @@ class TypesTest(fixtures.TestBase): finally: exec_sql(connection, "DROP TABLE Z_TEST") + @testing.fails_if("oracle<=23.4") def test_vector_dim(self, metadata, connection): - t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32"))) + t1 = Table( + "t1", + metadata, + Column( + "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32) + ), + ) - if testing.against("oracle>23.4"): - t1.create(connection) - eq_(t1.c.c1.type.dim, 3) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + t1.create(connection) + eq_(t1.c.c1.type.dim, 3) + @testing.fails_if("oracle<=23.4") def test_vector_insert(self, metadata, connection): t1 = Table( "t1", metadata, Column("id", Integer, primary_key=True), - Column("c1", VECTOR("int8")), + Column("c1", VECTOR(storage_format=VectorStorageFormat.INT8)), ) - if testing.against("oracle>23.4"): - t1.create(connection) - connection.execute( - t1.insert(), - dict(id=1, c1=[6, 7, 8, 5]), - ) - eq_( - connection.execute(t1.select()).first(), - (1, [6, 7, 8, 5]), - ) - connection.execute(t1.delete().where(t1.c.id == 1)) - connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) - eq_( - connection.execute(t1.select()).first(), - (1, [6, 7]), - ) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - 
t1.create(connection) + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6, 7, 8, 5]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + @testing.fails_if("oracle<=23.4") def test_vector_insert_array(self, metadata, connection): t1 = Table( "t1", @@ -1001,30 +1005,27 @@ class TypesTest(fixtures.TestBase): Column("c1", VECTOR), ) - if testing.against("oracle>23.4"): - t1.create(connection) - connection.execute( - t1.insert(), - dict(id=1, c1=array.array("b", [6, 7, 8, 5])), - ) - eq_( - connection.execute(t1.select()).first(), - (1, [6, 7, 8, 5]), - ) + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=array.array("b", [6, 7, 8, 5])), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) - connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.delete().where(t1.c.id == 1)) - connection.execute( - t1.insert(), dict(id=1, c1=array.array("b", [6, 7])) - ) - eq_( - connection.execute(t1.select()).first(), - (1, [6, 7]), - ) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + connection.execute( + t1.insert(), dict(id=1, c1=array.array("b", [6, 7])) + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + @testing.fails_if("oracle<=23.4") def test_vector_multiformat_insert(self, metadata, connection): t1 = Table( "t1", @@ -1033,119 +1034,117 @@ class TypesTest(fixtures.TestBase): Column("c1", VECTOR), ) - if testing.against("oracle>23.4"): - t1.create(connection) - connection.execute( - t1.insert(), - dict(id=1, c1=[6.12, 7.54, 8.33]), - ) - eq_( - connection.execute(t1.select()).first(), - (1, [6.12, 7.54, 8.33]), - ) - connection.execute(t1.delete().where(t1.c.id == 1)) - connection.execute(t1.insert(), 
dict(id=1, c1=[6, 7])) - eq_( - connection.execute(t1.select()).first(), - (1, [6, 7]), - ) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6.12, 7.54, 8.33]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6.12, 7.54, 8.33]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + @testing.fails_if("oracle<=23.4") def test_vector_format(self, metadata, connection): - t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32"))) + t1 = Table( + "t1", + metadata, + Column( + "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32) + ), + ) - if testing.against("oracle>23.4"): - t1.create(connection) - eq_(t1.c.c1.type.storage_format, "float32") - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + t1.create(connection) + eq_(t1.c.c1.type.storage_format, VectorStorageFormat.FLOAT32) + @testing.fails_if("oracle<=23.4") def test_vector_hnsw_index(self, metadata, connection): t1 = Table( "t1", metadata, Column("id", Integer), - Column("embedding", VECTOR(3, "float32")), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), ) - if testing.against("oracle>23.4"): - t1.create(connection) + t1.create(connection) - hnsw_index = Index( - "hnsw_vector_index", t1.c.embedding, oracle_vector=True - ) - hnsw_index.create(connection) + hnsw_index = Index( + "hnsw_vector_index", t1.c.embedding, oracle_vector=True + ) + hnsw_index.create(connection) - connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) - eq_( - connection.execute(t1.select()).first(), - (1, [6.0, 7.0, 8.0]), - ) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + connection.execute(t1.insert(), dict(id=1, 
embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + @testing.fails_if("oracle<=23.4") def test_vector_ivf_index(self, metadata, connection): t1 = Table( "t1", metadata, Column("id", Integer), - Column("embedding", VECTOR(3, "float32")), - ) - - if testing.against("oracle>23.4"): - t1.create(connection) - ivf_index = Index( - "ivf_vector_index", - t1.c.embedding, - oracle_vector={ - "accuracy": 90, - "distance": "DOT", - "parameters": {"type": "IVF", "neighbor partitions": 10}, - }, - ) - ivf_index.create(connection) + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) - connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) - eq_( - connection.execute(t1.select()).first(), - (1, [6.0, 7.0, 8.0]), - ) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + t1.create(connection) + ivf_index = Index( + "ivf_vector_index", + t1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.IVF, + distance=VectorDistanceType.DOT, + accuracy=90, + ivf_neighbor_partitions=5, + ), + ) + ivf_index.create(connection) + connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + + @testing.fails_if("oracle<=23.4") def test_vector_l2_distance(self, metadata, connection): t1 = Table( "t1", metadata, Column("id", Integer), - Column("embedding", VECTOR(3, "int8")), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.INT8), + ), ) - if testing.against("oracle>23.4"): - t1.create(connection) + t1.create(connection) - connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10])) - connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3])) - connection.execute( - t1.insert(), - dict(id=3, embedding=[15, 16, 17]), - ) + connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10])) + 
connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3])) + connection.execute( + t1.insert(), + dict(id=3, embedding=[15, 16, 17]), + ) - query_vector = [2, 3, 4] - res = connection.execute( - t1.select().order_by( - (t1.c.embedding.l2_distance(query_vector)) - ) - ).first() - eq_(res.embedding, [1, 2, 3]) - else: - with expect_raises_message(exc.DatabaseError, "ORA-03060"): - t1.create(connection) + query_vector = [2, 3, 4] + res = connection.execute( + t1.select().order_by((t1.c.embedding.l2_distance(query_vector))) + ).first() + eq_(res.embedding, [1, 2, 3]) class LOBFetchTest(fixtures.TablesTest): -- 2.47.3