TIMESTAMP,
VARCHAR,
VARCHAR2,
+ VECTOR,
)
.. versionadded:: 1.2.19 Added :class:`_types.NCHAR` to the list of datatypes
.. autoclass:: TIMESTAMP
:members: __init__
+.. autoclass:: VECTOR
+ :members: __init__
+
.. _oracledb:
python-oracledb
from .base import VARCHAR
from .base import VARCHAR2
from .base import VECTOR
+from .base import VectorIndexConfig
+from .base import VectorIndexType
+from .vector import VectorDistanceType
+from .vector import VectorStorageFormat
# Alias oracledb also as oracledb_async
oracledb_async = type(
"ROWID",
"REAL",
"VECTOR",
+ "VectorDistanceType",
+ "VectorIndexType",
+ "VectorIndexConfig",
+ "VectorStorageFormat",
)
number of prefix columns to compress, or ``True`` to use the default (all
columns for non-unique indexes, all but the last column for unique indexes).
+VECTOR Datatype
+---------------
+
+Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine
+learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers,
+8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers.
+For more information on the VECTOR datatype please visit this `link
+<https://python-oracledb.readthedocs.io/en/latest/user_guide/vector_data_type.html>`_.
+
+CREATE TABLE
+~~~~~~~~~~~~
+
+With the :class:`.VECTOR` datatype, you can specify the dimension for the data and the storage
+format. Valid values for storage format are enum values from :class:`.VectorStorageFormat`
+(:attr:`.VectorStorageFormat.INT8`, :attr:`.VectorStorageFormat.BINARY`, :attr:`.VectorStorageFormat.FLOAT32`,
+:attr:`.VectorStorageFormat.FLOAT64`).
+To create a table that includes a :class:`.VECTOR` column::
+
+ from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat
+
+ t = Table("t1", metadata,
+ Column('id', Integer, primary_key=True),
+ Column("embedding", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)),
+ Column(...), ...
+ )
+
+Vectors can also be defined with an arbitrary number of dimensions and formats. This allows
+you to specify vectors of different dimensions with the various storage formats mentioned above.
+
+For Example
+
+* In this case, the storage format is flexible, allowing any vector type data to be inserted,
+ such as INT8 or BINARY::
+
+ vector_col:Mapped[array.array] = mapped_column(VECTOR(dim=3))
+
+* The dimension is flexible in this case, meaning that any dimension vector can be used::
+
+ vector_col:Mapped[array.array] = mapped_column(VECTOR(storage_format=VectorStorageFormat.INT8))
+
+* Both the dimensions and the storage format are flexible::
+
+ vector_col:Mapped[array.array] = mapped_column(VECTOR)
+
+INSERT VECTOR DATA
+~~~~~~~~~~~~~~~~~~
+
+VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type
+FLOAT (32-bit), DOUBLE (64-bit), or INT (8-bit signed integer) are used as bind values when
+inserting VECTOR columns::
+
+ from sqlalchemy import insert, select
+ import array
+
+ vector_data_8 = [1, 2, 3]
+ statement = insert(t1)
+ with engine.connect() as conn:
+ conn.execute(statement,[
+ {"id":1,"embedding":vector_data_8},
+ ])
+
+VECTOR INDEXES
+~~~~~~~~~~~~~~
+
+There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW
+index.
+
+To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use
+the default values provided by Oracle. HNSW is the default indexing method::
+
+ Index(
+ 'vector_index',
+ t1.c.embedding,
+ oracle_vector = True,
+ )
+
+If you wish to use custom parameters, you can specify all the parameters using the
+:class:`.VectorIndexConfig` dataclass in the ``oracle_vector`` option. To learn more about the
+parameters that can be passed please visit this `link
+<https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/create-vector-index.html>`_.
+For example::
+
+ Index(
+ 'hnsw_vector_index',
+ t1.c.embedding,
+ oracle_vector = VectorIndexConfig(
+ index_type = VectorIndexType.HNSW,
+ distance = VectorDistanceType.COSINE,
+ accuracy = 90,
+ hnsw_neighbors = 5,
+ hnsw_efconstruction = 20,
+ parallel = 10
+ )
+ )
+
+ Index(
+ 'ivf_vector_index',
+ t1.c.embedding,
+ oracle_vector = VectorIndexConfig(
+ index_type = VectorIndexType.IVF,
+ distance = VectorDistanceType.DOT,
+ accuracy = 90,
+ ivf_neighbor_partitions = 5,
+ )
+ )
+
+Similarity Searching
+~~~~~~~~~~~~~~~~~~~~
+
+You can use the following shorthand VECTOR distance functions:
+
+* ``l2_distance``
+* ``cosine_distance``
+* ``inner_product``
+
+Example Usage::
+
+ from sqlalchemy.orm import Session
+ from sqlalchemy.sql import func
+ import array
+
+ session = Session(bind=engine)
+ query_vector = [2,3,4]
+ result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3))
+
+ for user in result_vector:
+ print(user.id,user.embedding)
+
+.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database.
+
""" # noqa
from __future__ import annotations
from collections import defaultdict
+from dataclasses import fields
from functools import lru_cache
from functools import wraps
import re
from .types import ROWID # noqa
from .types import TIMESTAMP
from .types import VARCHAR2 # noqa
-from .types import VECTOR
+from .vector import VECTOR
+from .vector import VectorIndexConfig
+from .vector import VectorIndexType
from ... import Computed
from ... import exc
from ... import schema as sa_schema
def visit_VECTOR(self, type_, **kw):
if type_.dim is None and type_.storage_format is None:
- return f"VECTOR"
+ return "VECTOR(*,*)"
elif type_.storage_format is None:
return f"VECTOR({type_.dim},*)"
elif type_.dim is None:
- return f"VECTOR(*, {type_.storage_format})"
+ return f"VECTOR(*,{type_.storage_format.value})"
else:
- return f"VECTOR({type_.dim},{type_.storage_format})"
+ return f"VECTOR({type_.dim},{type_.storage_format.value})"
class OracleCompiler(compiler.SQLCompiler):
class OracleDDLCompiler(compiler.DDLCompiler):
+
+ def _build_vector_index_config(
+ self, vector_index_config: VectorIndexConfig
+ ) -> str:
+ parts = []
+ sql_param_name = {
+ "hnsw_neighbors": "neighbors",
+ "hnsw_efconstruction": "efconstruction",
+ "ivf_neighbor_partitions": "neighbor partitions",
+ "ivf_sample_per_partition": "sample_per_partition",
+ "ivf_min_vectors_per_partition": "min_vectors_per_partition",
+ }
+ if vector_index_config.index_type == VectorIndexType.HNSW:
+ parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH")
+ elif vector_index_config.index_type == VectorIndexType.IVF:
+ parts.append("ORGANIZATION NEIGHBOR PARTITIONS")
+ if vector_index_config.distance is not None:
+ parts.append(f"DISTANCE {vector_index_config.distance.value}")
+
+ if vector_index_config.accuracy is not None:
+ parts.append(
+ f"WITH TARGET ACCURACY {vector_index_config.accuracy}"
+ )
+
+ parameters_str = [f"type {vector_index_config.index_type.name}"]
+ prefix = vector_index_config.index_type.name.lower() + "_"
+
+ for field in fields(vector_index_config):
+ if field.name.startswith(prefix):
+ key = sql_param_name.get(field.name)
+ value = getattr(vector_index_config, field.name)
+ if value is not None:
+ parameters_str.append(f"{key} {value}")
+
+ parameters_str = ", ".join(parameters_str)
+ parts.append(f"PARAMETERS ({parameters_str})")
+
+ if vector_index_config.parallel is not None:
+ parts.append(f"PARALLEL {vector_index_config.parallel}")
+
+ return " ".join(parts)
+
def define_constraint_cascades(self, constraint):
text = ""
if constraint.ondelete is not None:
)
if vector_options:
if vector_options is True:
- vector_options = {}
- parts = []
- parameters = vector_options.get("parameters", {})
- using = parameters.get("type", "HNSW").upper()
- if using == "HNSW":
- parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH")
- elif using == "IVF":
- parts.append("ORGANIZATION NEIGHBOR PARTITIONS")
- vector_distance = vector_options.get("distance")
- if vector_distance is not None:
- vector_distance = vector_distance.upper()
- if vector_distance not in (
- "EUCLIDEAN",
- "DOT",
- "COSINE",
- "MANHATTAN",
- ):
- raise ValueError("Unknown vector_distance value")
- parts.append(f"DISTANCE {vector_distance}")
- target_accuracy = vector_options.get("accuracy")
- if target_accuracy is not None:
- if target_accuracy < 0 or target_accuracy > 100:
- raise ValueError(
- "Accuracy value should be an integer between 0 and 100"
- )
- parts.append(f"WITH TARGET ACCURACY {target_accuracy}")
- if parameters:
- parameters_str = ", ".join(
- f"{k} {v}" for k, v in parameters.items()
- )
- parts.append(f"PARAMETERS ({parameters_str})")
- parallel = vector_options.get("parallel")
- if parallel is not None:
- if not isinstance(parallel, int):
- raise ValueError("Parallel value must be an integer")
- parts.append(f"PARALLEL {parallel}")
- text += " " + " ".join(parts)
+ vector_options = VectorIndexConfig()
+
+ text += " " + self._build_vector_index_config(vector_options)
return text
def post_create_table(self, table):
.. versionadded:: 2.0.0 added support for the python-oracledb driver.
-VECTOR Datatype
----------------
-
-Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine
-learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers,
-8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers.
-For more information on the VECTOR datatype please visit this `link.
-<https://python-oracledb.readthedocs.io/en/latest/user_guide/vector_data_type.html>`_
-
-CREATE TABLE
-^^^^^^^^^^^^
-
-With the VECTOR datatype, you can specify the dimension for the data and the storage
-format. To create a table that includes a VECTOR column::
-
- from sqlalchemy.dialects.oracle import VECTOR
-
- t = Table("t1", metadata,
- Column('id', Integer, primary_key=True),
- Column("embedding", VECTOR(3, 'float32'),
- Column(...), ...
- )
-
-Vectors can also be defined with an arbitrary number of dimensions and formats. This allows
-you to specify vectors of different dimensions with the various storage formats mentioned above.
-
-For Example
-
-* In this case, the storage format is flexible, allowing any vector type data to be inserted,
- such as int8 or binary etc.
-
- vector_col:Mapped[array.array] = mapped_column(VECTOR(3))
-
-* The dimension is flexible in this case, meaning that any dimension vector can be used.
-
- vector_col:Mapped[array.array] = mapped_column(VECTOR('int8'))
-
-* Both the dimensions and the storage format are flexible.
-
- vector_col:Mapped[array.array] = mapped_column(VECTOR)
-
-INSERT VECTOR DATA
-^^^^^^^^^^^^^^^^^^
-
-VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type
-float (32-bit), double (64-bit), or int8_t (8-bit signed integer) are used as bind values when
-inserting VECTOR columns::
-
- from sqlalchemy import insert, select
- import array
-
- vector_data_8 = [1, 2, 3]
- statement = insert(t1)
- with engine.connect() as conn:
- conn.execute(statement,[
- {"id":1,"embedding":vector_data_8},
- ])
-
-VECTOR INDEXES
-^^^^^^^^^^^^^^
-
-There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW
-index.
-
-To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use
-the default values provided by Oracle. HNSW is the default indexing method::
-
- Index(
- 'vector_index',
- t1.c.embedding,
- oracle_vector = True,
- )
-
-If you wish to use custom parameters, you can specify all the parameters as a dictionary
-in the `oracle_vector` option. To learn more about the parameters that can be passed please
-visit this `link. <https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/create-vector-index.html>`_
-
-Configuring Oracle VECTOR Indexes
-=================================
-
-When using Oracle VECTOR indexes, the configuration parameters are divided into two levels:
-**top-level keys** and **nested keys**. This structure applies to both HNSW and IVF VECTOR indexes.
-
-Top-Level Keys
-==============
-
-These keys are specified directly under the ``oracle_vector`` dictionary.
-
-* ``accuracy``:
- - Specifies the accuracy of the nearest neighbor search during query execution.
- - **Valid Range**: Greater than 0 and less than or equal to 100.
- - **Example**: ``'accuracy': 95``
-
-* ``distance``:
- - Specifies the metric for calculating distance between VECTORS.
- - **Valid Values**: ``"EUCLIDEAN"``, ``"COSINE"``, ``"DOT"``, ``"MANHATTAN"``.
- - **Example**: ``'distance': "COSINE"``
-
-* ``parameters``:
- - A nested dictionary where method-specific options are defined (e.g., HNSW or IVF-specific settings).
- - **Example**: ``'parameters': {...}``
-
-Nested keys in parameters
-=========================
-
-These keys are specific to the indexing method and are included under the ``parameters`` dictionary.
-
-HNSW Parameters
-===============
-
-* ``type``:
- - Specifies the indexing method. For HNSW, this must be ``"HNSW"``.
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'type': 'HNSW'``
-
-* ``neighbors``:
- - The number of nearest neighbors considered during the search.
- - **Valid Range**: Greater than 0 and less than or equal to 2048.
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'neighbors': 20``
-
-* ``efconstruction``:
- - Controls the trade-off between indexing speed and recall quality during index construction.
- - **Valid Range**: Greater than 0 and less than or equal to 65535.
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'efconstruction': 300``
-
-IVF Parameters
-==============
-
-* ``type``:
- - Specifies the indexing method. For IVF, this must be ``"IVF"``.
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'type': 'IVF'``
-
-* ``neighbor partitions``:
- - The number of partitions used to divide the dataset.
- - **Valid Range**: Greater than 0 and less than or equal to 10,000,000.
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'neighbor partitions': 10``
-
-* ``sample_per_partition``:
- - The number of samples used per partition.
- - **Valid Range**: Between 1 and ``num_vectors / neighbor_partitions``.
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'sample_per_partition': 5``
-
-* ``min_vectors_per_partition``:
- - The minimum number of vectors per partition.
- - **Valid Range**: From 0 (no trimming) to the total number of vectors (results in 1 partition).
- - **Placement**: Nested under ``parameters``.
- - **Example**: ``'min_vectors_per_partition': 100``
-
-Example Configurations
-======================
-
-For custom configurations, the parameters can be specified as shown in the following examples::
-
- Index(
- 'hnsw_vector_index',
- t1.c.embedding,
- oracle_vector = {
- 'accuracy':95,
- 'distance':"COSINE",
- 'parameters':{
- 'type':'HNSW',
- 'neighbors':20',
- efconstruction':300
- }
- }
- )
-
- Index(
- 'ivf_vector_index',
- t1.c.embedding,
- oracle_vector = {
- 'accuracy':90,
- 'distance':"DOT",
- 'parameters':{
- 'type':'IVF','
- neighbor partitions':10
- }
- }
- )
-
-Similarity Searching
-^^^^^^^^^^^^^^^^^^^^
-
-You can use the following shorthand VECTOR distance functions:
-
-* ``l2_distance``
-* ``cosine_distance``
-* ``inner_product``
-
-Example Usage::
-
- from sqlalchemy.orm import Session
- from sqlalchemy.sql import func
- import array
-
- session = Session(bind=engine)
- query_vector = [2,3,4]
- result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3))
-
- for user in vector:
- print(user.id,user.embedding)
-
-Exact and Approximate Searching
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Similarity searches tend to get data from one or more clusters depending on the value of the query VECTOR and the fetch
-size. Approximate searches using VECTOR indexes can limit the searches to specific clusters, whereas exact searches visit
-VECTORS across all clusters.
-You can use the fetch_type clause to set the searching to be either EXACT, APPROX or APPROXIMATE::
-
- result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("EXACT")
- result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("APPROX")
-
-
-.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database.
-:class:`_oracle.VECTOR` datatype.
-
""" # noqa
from __future__ import annotations
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
-import sqlalchemy.types as types
-from sqlalchemy.types import UserDefinedType, Float
from ... import exc
from ...sql import sqltypes
from ...types import NVARCHAR
from ...types import VARCHAR
-import array
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
class _OracleBoolean(sqltypes.Boolean):
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
-
-
-class VECTOR(types.TypeEngine):
- """Oracle VECTOR datatype."""
-
- cache_ok = True
- __visit_name__ = "VECTOR"
-
- def __init__(self, dim=None, storage_format=None, *args):
- """
- :param dim: The dimension of the VECTOR datatype. This should be an
- integer value.
- :param storage_format: The VECTOR storage type format. This
- may be int8, binary, float32, or float64.
- """
- if dim is not None and isinstance(dim, int):
- self.dim = dim
- self.storage_format = storage_format
-
- elif dim is not None and isinstance(dim, str):
- self.dim = storage_format
- self.storage_format = dim
-
- else:
- self.dim = storage_format
- self.storage_format = dim
-
- def _cached_bind_processor(self, dialect):
- """
- Convert a list to a array.array before binding it to the database.
- """
-
- def process(value):
- if value is None or isinstance(value, array.array):
- return value
-
- # Convert list to a array.array
- elif isinstance(value, list):
- format = self._array_typecode(self.storage_format)
- value = array.array(format, value)
- return value
-
- else:
- raise TypeError("VECTOR accepts list or array.array()")
-
- return process
-
- def _cached_result_processor(self, dialect, coltype):
- """
- Convert a array.array to list before binding it to the database.
- """
-
- def process(value):
- if isinstance(value, array.array):
- return list(value)
-
- return process
-
- def _array_typecode(self, format):
- """
- Map storage format to array typecode.
- """
- typecode_map = {
- "int8": "b", # Signed int
- "binary": "B", # Unsigned int
- "float32": "f", # Float
- "float64": "d", # Double
- }
- return typecode_map.get(format, "d")
-
- class comparator_factory(types.TypeEngine.Comparator):
- def l2_distance(self, other):
- return self.op("<->", return_type=Float)(other)
-
- def inner_product(self, other):
- return self.op("<#>", return_type=Float)(other)
-
- def cosine_distance(self, other):
- return self.op("<=>", return_type=Float)(other)
--- /dev/null
+# dialects/oracle/vector.py
+# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+from __future__ import annotations
+
+import array
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+import sqlalchemy.types as types
+from sqlalchemy.types import Float
+
+
+class VectorIndexType(Enum):
+ """Enum representing different types of VECTOR index structures.
+
+ Each index type has different characteristics and configuration
+ parameters optimized for specific use cases in vector search.
+
+ Members are passed as ``index_type`` to :class:`.VectorIndexConfig`,
+ which defaults to :attr:`.VectorIndexType.HNSW`.
+
+ """
+
+ HNSW = "HNSW"
+ """
+ The HNSW (Hierarchical Navigable Small World) index type.
+ """
+ IVF = "IVF"
+ """
+ The IVF (Inverted File Index) index type
+ """
+
+
+class VectorDistanceType(Enum):
+ """Enum representing different types of vector distance metrics.
+
+ Members are passed as ``distance`` to :class:`.VectorIndexConfig`; the
+ member value is the keyword rendered after ``DISTANCE`` in the vector
+ index DDL.
+
+ """
+
+ EUCLIDEAN = "EUCLIDEAN"
+ """Euclidean distance (L2 norm).
+
+ Measures the straight-line distance between two vectors in space.
+ """
+ DOT = "DOT"
+ """Dot product similarity.
+
+ Measures the algebraic similarity between two vectors.
+ """
+ COSINE = "COSINE"
+ """Cosine similarity.
+
+ Measures the cosine of the angle between two vectors.
+ """
+ MANHATTAN = "MANHATTAN"
+ """Manhattan distance (L1 norm).
+
+ Calculates the sum of absolute differences across dimensions.
+ """
+
+
+class VectorStorageFormat(Enum):
+ """Enum representing the data format used to store vector components.
+
+ Choosing the right format balances precision, memory usage,
+ and performance for vector indexing and search.
+
+ Members are passed as ``storage_format`` to :class:`.VECTOR`. When a
+ Python list is bound to a :class:`.VECTOR` column, the storage format
+ selects the ``array.array`` typecode used for the conversion: ``"b"``
+ for INT8, ``"B"`` for BINARY, ``"f"`` for FLOAT32 and ``"d"`` for
+ FLOAT64.
+ """
+
+ INT8 = "INT8"
+ """
+ 8-bit integer format.
+ """
+ BINARY = "BINARY"
+ """
+ Binary format.
+ """
+ FLOAT32 = "FLOAT32"
+ """
+ 32-bit floating-point format.
+ """
+ FLOAT64 = "FLOAT64"
+ """
+ 64-bit floating-point format.
+ """
+
+
+@dataclass
+class VectorIndexConfig:
+ """Define the configuration for Oracle VECTOR Index.
+
+ :param index_type: Enum values from :class:`.VectorIndexType
+ Specifies the indexing method. For HNSW, this must be
+ :attr:`.VectorIndexType.HNSW`.
+
+ :param distance: Enum values from :class:`.VectorDistanceType`
+ specifies the metric for calculating distance between VECTORS.
+
+ :param accuracy: interger. Should be in the range 0 to 100
+ Specifies the accuracy of the nearest neighbor search during
+ query execution.
+
+ :param parallel: integer. Specifies degree of parallelism.
+
+ :param hnsw_neighbors: interger. Should be in the range 0 to
+ 2048. Specifies the number of nearest neighbors considered
+ during the search The attribute :attr:`VectorIndexConfig.hnsw_neighbors`
+ is HNSW index specific.
+
+ :param hnsw_efconstruction: integer. Should be in the range 0
+ to 65535. Controls the trade-off between indexing speed and
+ recall quality during index construction.The attribute
+ :attr:`VectorIndexConfig.hnsw_efconstruction` is HNSW index
+ specific.
+
+ :param ivf_neighbour_partitions: integer. Should be in the range
+ 0 to 10,000,000. Specifies the number of partitions used to
+ divide the dataset. The attribute
+ :attr:`VectorIndexConfig.ivf_neighbour_partitions` is IVF index
+ specific.
+
+ :param ivf sample_per_partition: integer. Should be between 1
+ and ``num_vectors / neighbor partitions``. Specifies the
+ number of samples used per partition. The attribute
+ :attr:`VectorIndexConfig.ivf sample_per_partition` is IVF index
+ specific.
+
+ :param ivf_min_vectors_per_partition: integer. From 0 (no trimming)
+ to the total number of vectors (results in 1 partition). Specifies
+ the minimum number of vectors per partition.
+
+ """
+
+ index_type: VectorIndexType = VectorIndexType.HNSW
+ distance: Optional[VectorDistanceType] = None
+ accuracy: Optional[int] = None
+ hnsw_neighbors: Optional[int] = None
+ hnsw_efconstruction: Optional[int] = None
+ ivf_neighbor_partitions: Optional[int] = None
+ ivf_sample_per_partition: Optional[int] = None
+ ivf_min_vectors_per_partition: Optional[int] = None
+ parallel: Optional[int] = None
+
+ def __post_init__(self):
+ self.index_type = VectorIndexType(self.index_type)
+ for field in [
+ "hnsw_neighbors",
+ "hnsw_efconstruction",
+ "ivf_neighbor_partitions",
+ "ivf_sample_per_partition",
+ "ivf_min_vectors_per_partition",
+ "parallel",
+ "accuracy",
+ ]:
+ value = getattr(self, field)
+ if value is not None and not isinstance(value, int):
+ raise TypeError(
+ f"{field} must be an integer if"
+ f"provided, got {type(value).__name__}"
+ )
+
+
+class VECTOR(types.TypeEngine):
+ """Oracle VECTOR datatype."""
+
+ cache_ok = True
+ __visit_name__ = "VECTOR"
+
+ def __init__(self, dim=None, storage_format=None):
+ """Construct a VECTOR.
+
+ :param dim: integer. The dimension of the VECTOR datatype. This
+ should be an integer value.
+
+ :param storage_format: VectorStorageFormat. The VECTOR storage
+ type format. This may be Enum values form
+ `VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64.
+
+ """
+ if dim is not None and not isinstance(dim, int):
+ raise TypeError("dim must be an interger")
+ if storage_format is not None and not isinstance(
+ storage_format, VectorStorageFormat
+ ):
+ raise TypeError(
+ "storage_format must be an enum of type VectorStorageFormat"
+ )
+ self.dim = dim
+ self.storage_format = storage_format
+
+ def _cached_bind_processor(self, dialect):
+ """
+ Convert a list to a array.array before binding it to the database.
+ """
+
+ def process(value):
+ if value is None or isinstance(value, array.array):
+ return value
+
+ # Convert list to a array.array
+ elif isinstance(value, list):
+ typecode = self._array_typecode(self.storage_format)
+ value = array.array(typecode, value)
+ return value
+
+ else:
+ raise TypeError("VECTOR accepts list or array.array()")
+
+ return process
+
+ def _cached_result_processor(self, dialect, coltype):
+ """
+ Convert a array.array to list before binding it to the database.
+ """
+
+ def process(value):
+ if isinstance(value, array.array):
+ return list(value)
+
+ return process
+
+ def _array_typecode(self, typecode):
+ """
+ Map storage format to array typecode.
+ """
+ typecode_map = {
+ VectorStorageFormat.INT8: "b", # Signed int
+ VectorStorageFormat.BINARY: "B", # Unsigned int
+ VectorStorageFormat.FLOAT32: "f", # Float
+ VectorStorageFormat.FLOAT64: "d", # Double
+ }
+ return typecode_map.get(typecode, "d")
+
+ class comparator_factory(types.TypeEngine.Comparator):
+ def l2_distance(self, other):
+ return self.op("<->", return_type=Float)(other)
+
+ def inner_product(self, other):
+ return self.op("<#>", return_type=Float)(other)
+
+ def cosine_distance(self, other):
+ return self.op("<=>", return_type=Float)(other)
+import array
import datetime
import decimal
import os
import random
-import array
from sqlalchemy import bindparam
from sqlalchemy import cast
from sqlalchemy import exc
from sqlalchemy import FLOAT
from sqlalchemy import Float
+from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import LargeBinary
from sqlalchemy import literal
from sqlalchemy import Unicode
from sqlalchemy import UnicodeText
from sqlalchemy import VARCHAR
-from sqlalchemy import Index
-from sqlalchemy import delete
from sqlalchemy.dialects.oracle import base as oracle
from sqlalchemy.dialects.oracle import cx_oracle
from sqlalchemy.dialects.oracle import oracledb
from sqlalchemy.dialects.oracle import VECTOR
+from sqlalchemy.dialects.oracle import VectorDistanceType
+from sqlalchemy.dialects.oracle import VectorIndexConfig
+from sqlalchemy.dialects.oracle import VectorIndexType
+from sqlalchemy.dialects.oracle import VectorStorageFormat
from sqlalchemy.sql import column
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.testing import AssertsCompiledSQL
finally:
exec_sql(connection, "DROP TABLE Z_TEST")
+ @testing.fails_if("oracle<=23.4")
def test_vector_dim(self, metadata, connection):
- t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32")))
+ t1 = Table(
+ "t1",
+ metadata,
+ Column(
+ "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)
+ ),
+ )
- if testing.against("oracle>23.4"):
- t1.create(connection)
- eq_(t1.c.c1.type.dim, 3)
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ t1.create(connection)
+ eq_(t1.c.c1.type.dim, 3)
+ @testing.fails_if("oracle<=23.4")
def test_vector_insert(self, metadata, connection):
t1 = Table(
"t1",
metadata,
Column("id", Integer, primary_key=True),
- Column("c1", VECTOR("int8")),
+ Column("c1", VECTOR(storage_format=VectorStorageFormat.INT8)),
)
- if testing.against("oracle>23.4"):
- t1.create(connection)
- connection.execute(
- t1.insert(),
- dict(id=1, c1=[6, 7, 8, 5]),
- )
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6, 7, 8, 5]),
- )
- connection.execute(t1.delete().where(t1.c.id == 1))
- connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6, 7]),
- )
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ t1.create(connection)
+ connection.execute(
+ t1.insert(),
+ dict(id=1, c1=[6, 7, 8, 5]),
+ )
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6, 7, 8, 5]),
+ )
+ connection.execute(t1.delete().where(t1.c.id == 1))
+ connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6, 7]),
+ )
+ @testing.fails_if("oracle<=23.4")
def test_vector_insert_array(self, metadata, connection):
t1 = Table(
"t1",
Column("c1", VECTOR),
)
- if testing.against("oracle>23.4"):
- t1.create(connection)
- connection.execute(
- t1.insert(),
- dict(id=1, c1=array.array("b", [6, 7, 8, 5])),
- )
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6, 7, 8, 5]),
- )
+ t1.create(connection)
+ connection.execute(
+ t1.insert(),
+ dict(id=1, c1=array.array("b", [6, 7, 8, 5])),
+ )
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6, 7, 8, 5]),
+ )
- connection.execute(t1.delete().where(t1.c.id == 1))
+ connection.execute(t1.delete().where(t1.c.id == 1))
- connection.execute(
- t1.insert(), dict(id=1, c1=array.array("b", [6, 7]))
- )
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6, 7]),
- )
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ connection.execute(
+ t1.insert(), dict(id=1, c1=array.array("b", [6, 7]))
+ )
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6, 7]),
+ )
+ @testing.fails_if("oracle<=23.4")
def test_vector_multiformat_insert(self, metadata, connection):
t1 = Table(
"t1",
Column("c1", VECTOR),
)
- if testing.against("oracle>23.4"):
- t1.create(connection)
- connection.execute(
- t1.insert(),
- dict(id=1, c1=[6.12, 7.54, 8.33]),
- )
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6.12, 7.54, 8.33]),
- )
- connection.execute(t1.delete().where(t1.c.id == 1))
- connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6, 7]),
- )
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ t1.create(connection)
+ connection.execute(
+ t1.insert(),
+ dict(id=1, c1=[6.12, 7.54, 8.33]),
+ )
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6.12, 7.54, 8.33]),
+ )
+ connection.execute(t1.delete().where(t1.c.id == 1))
+ connection.execute(t1.insert(), dict(id=1, c1=[6, 7]))
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6, 7]),
+ )
+ @testing.fails_if("oracle<=23.4")
def test_vector_format(self, metadata, connection):
- t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32")))
+ t1 = Table(
+ "t1",
+ metadata,
+ Column(
+ "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32)
+ ),
+ )
- if testing.against("oracle>23.4"):
- t1.create(connection)
- eq_(t1.c.c1.type.storage_format, "float32")
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ t1.create(connection)
+ eq_(t1.c.c1.type.storage_format, VectorStorageFormat.FLOAT32)
+ @testing.fails_if("oracle<=23.4")
def test_vector_hnsw_index(self, metadata, connection):
t1 = Table(
"t1",
metadata,
Column("id", Integer),
- Column("embedding", VECTOR(3, "float32")),
+ Column(
+ "embedding",
+ VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+ ),
)
- if testing.against("oracle>23.4"):
- t1.create(connection)
+ t1.create(connection)
- hnsw_index = Index(
- "hnsw_vector_index", t1.c.embedding, oracle_vector=True
- )
- hnsw_index.create(connection)
+ hnsw_index = Index(
+ "hnsw_vector_index", t1.c.embedding, oracle_vector=True
+ )
+ hnsw_index.create(connection)
- connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6.0, 7.0, 8.0]),
- )
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6.0, 7.0, 8.0]),
+ )
+ @testing.fails_if("oracle<=23.4")
def test_vector_ivf_index(self, metadata, connection):
t1 = Table(
"t1",
metadata,
Column("id", Integer),
- Column("embedding", VECTOR(3, "float32")),
- )
-
- if testing.against("oracle>23.4"):
- t1.create(connection)
- ivf_index = Index(
- "ivf_vector_index",
- t1.c.embedding,
- oracle_vector={
- "accuracy": 90,
- "distance": "DOT",
- "parameters": {"type": "IVF", "neighbor partitions": 10},
- },
- )
- ivf_index.create(connection)
+ Column(
+ "embedding",
+ VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
+ ),
+ )
- connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
- eq_(
- connection.execute(t1.select()).first(),
- (1, [6.0, 7.0, 8.0]),
- )
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ t1.create(connection)
+ ivf_index = Index(
+ "ivf_vector_index",
+ t1.c.embedding,
+ oracle_vector=VectorIndexConfig(
+ index_type=VectorIndexType.IVF,
+ distance=VectorDistanceType.DOT,
+ accuracy=90,
+ ivf_neighbor_partitions=5,
+ ),
+ )
+ ivf_index.create(connection)
+ connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8]))
+ eq_(
+ connection.execute(t1.select()).first(),
+ (1, [6.0, 7.0, 8.0]),
+ )
+
+ @testing.fails_if("oracle<=23.4")
def test_vector_l2_distance(self, metadata, connection):
t1 = Table(
"t1",
metadata,
Column("id", Integer),
- Column("embedding", VECTOR(3, "int8")),
+ Column(
+ "embedding",
+ VECTOR(dim=3, storage_format=VectorStorageFormat.INT8),
+ ),
)
- if testing.against("oracle>23.4"):
- t1.create(connection)
+ t1.create(connection)
- connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10]))
- connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3]))
- connection.execute(
- t1.insert(),
- dict(id=3, embedding=[15, 16, 17]),
- )
+ connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10]))
+ connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3]))
+ connection.execute(
+ t1.insert(),
+ dict(id=3, embedding=[15, 16, 17]),
+ )
- query_vector = [2, 3, 4]
- res = connection.execute(
- t1.select().order_by(
- (t1.c.embedding.l2_distance(query_vector))
- )
- ).first()
- eq_(res.embedding, [1, 2, 3])
- else:
- with expect_raises_message(exc.DatabaseError, "ORA-03060"):
- t1.create(connection)
+ query_vector = [2, 3, 4]
+ res = connection.execute(
+ t1.select().order_by((t1.c.embedding.l2_distance(query_vector)))
+ ).first()
+ eq_(res.embedding, [1, 2, 3])
class LOBFetchTest(fixtures.TablesTest):