From: suraj Date: Mon, 5 May 2025 15:14:35 +0000 (-0400) Subject: Added vector datatype support in Oracle dialect X-Git-Tag: rel_2_0_41~9^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1ca4fe4159084c8ceda360f418c73118c4d15002;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Added vector datatype support in Oracle dialect Added new datatype :class:`_oracle.VECTOR` and accompanying DDL and DQL support to fully support this type for Oracle Database. This change includes the base :class:`_oracle.VECTOR` type that adds new type-specific methods ``l2_distance``, ``cosine_distance``, ``inner_product`` as well as new parameters ``oracle_vector`` for the :class:`.Index` construct, allowing vector indexes to be configured, and ``oracle_fetch_approximate`` for the :meth:`.Select.fetch` clause. Pull request courtesy Suraj Shaw. Fixes: #12317 Closes: #12321 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12321 Pull-request-sha: a72a18a45c85ae7fa50a34e97ac642e16b463b54 Change-Id: I6f3af4623ce439d0820c14582cd129df293f0ba8 (cherry picked from commit 1b780ce3d3f7e33e5cc9e49eafa316a514cdc324) --- diff --git a/doc/build/changelog/unreleased_20/12317.rst b/doc/build/changelog/unreleased_20/12317.rst new file mode 100644 index 0000000000..13f69693e6 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12317.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: usecase, oracle + :tickets: 12317, 12341 + + Added new datatype :class:`_oracle.VECTOR` and accompanying DDL and DQL + support to fully support this type for Oracle Database. This change + includes the base :class:`_oracle.VECTOR` type that adds new type-specific + methods ``l2_distance``, ``cosine_distance``, ``inner_product`` as well as + new parameters ``oracle_vector`` for the :class:`.Index` construct, + allowing vector indexes to be configured, and ``oracle_fetch_approximate`` + for the :meth:`.Select.fetch` clause. Pull request courtesy Suraj Shaw. + + .. seealso:: + + :ref:`oracle_vector_datatype` + diff --git a/doc/build/dialects/oracle.rst b/doc/build/dialects/oracle.rst index b3d44858ce..882f926604 100644 --- a/doc/build/dialects/oracle.rst +++ b/doc/build/dialects/oracle.rst @@ -31,6 +31,7 @@ originate from :mod:`sqlalchemy.types` or from the local dialect:: TIMESTAMP, VARCHAR, VARCHAR2, + VECTOR, ) .. versionadded:: 1.2.19 Added :class:`_types.NCHAR` to the list of datatypes @@ -80,6 +81,23 @@ construction arguments, are as follows: .. autoclass:: TIMESTAMP :members: __init__ +.. autoclass:: VECTOR + :members: __init__ + +.. autoclass:: VectorIndexType + :members: + +.. autoclass:: VectorIndexConfig + :members: + :undoc-members: + +.. autoclass:: VectorStorageFormat + :members: + +.. autoclass:: VectorDistanceType + :members: + + .. _oracledb: python-oracledb diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index 7ceb743d61..2265de033c 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -32,6 +32,11 @@ from .base import ROWID from .base import TIMESTAMP from .base import VARCHAR from .base import VARCHAR2 +from .base import VECTOR +from .base import VectorIndexConfig +from .base import VectorIndexType +from .vector import VectorDistanceType +from .vector import VectorStorageFormat # Alias oracledb also as oracledb_async oracledb_async = type( @@ -64,4 +69,9 @@ __all__ = ( "NVARCHAR2", "ROWID", "REAL", + "VECTOR", + "VectorDistanceType", + "VectorIndexType", + "VectorIndexConfig", + "VectorStorageFormat", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 02aa4d5366..1d882def8d 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -722,11 +722,177 @@ The ``oracle_compress`` parameter accepts either an integer specifying the number of prefix columns to compress, or ``True`` to use the default (all columns for non-unique indexes, all but the last column for unique indexes). +.. _oracle_vector_datatype: + +VECTOR Datatype +--------------- + +Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence +and machine learning search operations. The VECTOR datatype is a homogeneous array +of 8-bit signed integers, 8-bit unsigned integers (binary), 32-bit floating-point numbers, +or 64-bit floating-point numbers. + +.. seealso:: + + `Using VECTOR Data + `_ - in the documentation + for the :ref:`oracledb` driver. + +.. versionadded:: 2.0.41 + +CREATE TABLE support for VECTOR +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +With the :class:`.VECTOR` datatype, you can specify the dimension for the data +and the storage format. Valid values for storage format are enum values from +:class:`.VectorStorageFormat`. To create a table that includes a +:class:`.VECTOR` column:: + + from sqlalchemy.dialects.oracle import VECTOR, VectorStorageFormat + + t = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + Column(...), + ..., + ) + +Vectors can also be defined with an arbitrary number of dimensions and formats. +This allows you to specify vectors of different dimensions with the various +storage formats mentioned above. + +**Examples** + +* In this case, the storage format is flexible, allowing any vector type data to be inserted, + such as INT8 or BINARY etc:: + + vector_col: Mapped[array.array] = mapped_column(VECTOR(dim=3)) + +* The dimension is flexible in this case, meaning that any dimension vector can be used:: + + vector_col: Mapped[array.array] = mapped_column( + VECTOR(storage_format=VectorStorageType.INT8) + ) + +* Both the dimensions and the storage format are flexible:: + + vector_col: Mapped[array.array] = mapped_column(VECTOR) + +Python Datatypes for VECTOR +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +VECTOR data can be inserted using Python list or Python ``array.array()`` objects. +Python arrays of type FLOAT (32-bit), DOUBLE (64-bit), or INT (8-bit signed integer) +are used as bind values when inserting VECTOR columns:: + + from sqlalchemy import insert, select + + with engine.begin() as conn: + conn.execute( + insert(t1), + {"id": 1, "embedding": [1, 2, 3]}, + ) + +VECTOR Indexes +~~~~~~~~~~~~~~ + +The VECTOR feature supports an Oracle-specific parameter ``oracle_vector`` +on the :class:`.Index` construct, which allows the construction of VECTOR +indexes. + +To utilize VECTOR indexing, set the ``oracle_vector`` parameter to True to use +the default values provided by Oracle. HNSW is the default indexing method:: + + from sqlalchemy import Index + + Index( + "vector_index", + t1.c.embedding, + oracle_vector=True, + ) + +The full range of parameters for vector indexes are available by using the +:class:`.VectorIndexConfig` dataclass in place of a boolean; this dataclass +allows full configuration of the index:: + + Index( + "hnsw_vector_index", + t1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.HNSW, + distance=VectorDistanceType.COSINE, + accuracy=90, + hnsw_neighbors=5, + hnsw_efconstruction=20, + parallel=10, + ), + ) + + Index( + "ivf_vector_index", + t1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.IVF, + distance=VectorDistanceType.DOT, + accuracy=90, + ivf_neighbor_partitions=5, + ), + ) + +For complete explanation of these parameters, see the Oracle documentation linked +below. + +.. seealso:: + + `CREATE VECTOR INDEX `_ - in the Oracle documentation + + + +Similarity Searching +~~~~~~~~~~~~~~~~~~~~ + +When using the :class:`_oracle.VECTOR` datatype with a :class:`.Column` or similar +ORM mapped construct, additional comparison functions are available, including: + +* ``l2_distance`` +* ``cosine_distance`` +* ``inner_product`` + +Example Usage:: + + result_vector = connection.scalars( + select(t1).order_by(t1.embedding.l2_distance([2, 3, 4])).limit(3) + ) + + for user in vector: + print(user.id, user.embedding) + +FETCH APPROXIMATE support +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Approximate vector search can only be performed when all syntax and semantic +rules are satisfied, the corresponding vector index is available, and the +query optimizer determines to perform it. If any of these conditions are +unmet, then an approximate search is not performed. In this case the query +returns exact results. + +To enable approximate searching during similarity searches on VECTORS, the +``oracle_fetch_approximate`` parameter may be used with the :meth:`.Select.fetch` +clause to add ``FETCH APPROX`` to the SELECT statement:: + + select(users_table).fetch(5, oracle_fetch_approximate=True) + """ # noqa from __future__ import annotations from collections import defaultdict +from dataclasses import fields from functools import lru_cache from functools import wraps import re @@ -749,6 +915,9 @@ from .types import RAW from .types import ROWID # noqa from .types import TIMESTAMP from .types import VARCHAR2 # noqa +from .vector import VECTOR +from .vector import VectorIndexConfig +from .vector import VectorIndexType from ... import Computed from ... import exc from ... import schema as sa_schema @@ -767,6 +936,7 @@ from ...sql import func from ...sql import null from ...sql import or_ from ...sql import select +from ...sql import selectable as sa_selectable from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors @@ -828,6 +998,7 @@ ischema_names = { "BINARY_DOUBLE": BINARY_DOUBLE, "BINARY_FLOAT": BINARY_FLOAT, "ROWID": ROWID, + "VECTOR": VECTOR, } @@ -985,6 +1156,16 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_ROWID(self, type_, **kw): return "ROWID" + def visit_VECTOR(self, type_, **kw): + if type_.dim is None and type_.storage_format is None: + return "VECTOR(*,*)" + elif type_.storage_format is None: + return f"VECTOR({type_.dim},*)" + elif type_.dim is None: + return f"VECTOR(*,{type_.storage_format.value})" + else: + return f"VECTOR({type_.dim},{type_.storage_format.value})" + class OracleCompiler(compiler.SQLCompiler): """Oracle compiler modifies the lexical structure of Select @@ -1223,6 +1404,29 @@ class OracleCompiler(compiler.SQLCompiler): else: return select._fetch_clause + def fetch_clause( + self, + select, + fetch_clause=None, + require_offset=False, + use_literal_execute_for_simple_int=False, + **kw, + ): + text = super().fetch_clause( + select, + fetch_clause=fetch_clause, + require_offset=require_offset, + use_literal_execute_for_simple_int=( + use_literal_execute_for_simple_int + ), + **kw, + ) + + if select.dialect_options["oracle"]["fetch_approximate"]: + text = re.sub("FETCH FIRST", "FETCH APPROX FIRST", text) + + return text + def translate_select_structure(self, select_stmt, **kwargs): select = select_stmt @@ -1471,6 +1675,48 @@ class OracleCompiler(compiler.SQLCompiler): class OracleDDLCompiler(compiler.DDLCompiler): + + def _build_vector_index_config( + self, vector_index_config: VectorIndexConfig + ) -> str: + parts = [] + sql_param_name = { + "hnsw_neighbors": "neighbors", + "hnsw_efconstruction": "efconstruction", + "ivf_neighbor_partitions": "neighbor partitions", + "ivf_sample_per_partition": "sample_per_partition", + "ivf_min_vectors_per_partition": "min_vectors_per_partition", + } + if vector_index_config.index_type == VectorIndexType.HNSW: + parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH") + elif vector_index_config.index_type == VectorIndexType.IVF: + parts.append("ORGANIZATION NEIGHBOR PARTITIONS") + if vector_index_config.distance is not None: + parts.append(f"DISTANCE {vector_index_config.distance.value}") + + if vector_index_config.accuracy is not None: + parts.append( + f"WITH TARGET ACCURACY {vector_index_config.accuracy}" + ) + + parameters_str = [f"type {vector_index_config.index_type.name}"] + prefix = vector_index_config.index_type.name.lower() + "_" + + for field in fields(vector_index_config): + if field.name.startswith(prefix): + key = sql_param_name.get(field.name) + value = getattr(vector_index_config, field.name) + if value is not None: + parameters_str.append(f"{key} {value}") + + parameters_str = ", ".join(parameters_str) + parts.append(f"PARAMETERS ({parameters_str})") + + if vector_index_config.parallel is not None: + parts.append(f"PARALLEL {vector_index_config.parallel}") + + return " ".join(parts) + def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -1503,6 +1749,9 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " + vector_options = index.dialect_options["oracle"]["vector"] + if vector_options: + text += "VECTOR " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), @@ -1520,6 +1769,11 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += " COMPRESS %d" % ( index.dialect_options["oracle"]["compress"] ) + if vector_options: + if vector_options is True: + vector_options = VectorIndexConfig() + + text += " " + self._build_vector_index_config(vector_options) return text def post_create_table(self, table): @@ -1670,7 +1924,16 @@ class OracleDialect(default.DefaultDialect): "tablespace": None, }, ), - (sa_schema.Index, {"bitmap": False, "compress": False}), + ( + sa_schema.Index, + { + "bitmap": False, + "compress": False, + "vector": False, + }, + ), + (sa_selectable.Select, {"fetch_approximate": False}), + (sa_selectable.CompoundSelect, {"fetch_approximate": False}), ] @util.deprecated_params( diff --git a/lib/sqlalchemy/dialects/oracle/vector.py b/lib/sqlalchemy/dialects/oracle/vector.py new file mode 100644 index 0000000000..dae89d3418 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/vector.py @@ -0,0 +1,266 @@ +# dialects/oracle/vector.py +# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import array +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +import sqlalchemy.types as types +from sqlalchemy.types import Float + + +class VectorIndexType(Enum): + """Enum representing different types of VECTOR index structures. + + See :ref:`oracle_vector_datatype` for background. + + .. versionadded:: 2.0.41 + + """ + + HNSW = "HNSW" + """ + The HNSW (Hierarchical Navigable Small World) index type. + """ + IVF = "IVF" + """ + The IVF (Inverted File Index) index type + """ + + +class VectorDistanceType(Enum): + """Enum representing different types of vector distance metrics. + + See :ref:`oracle_vector_datatype` for background. + + .. versionadded:: 2.0.41 + + """ + + EUCLIDEAN = "EUCLIDEAN" + """Euclidean distance (L2 norm). + + Measures the straight-line distance between two vectors in space. + """ + DOT = "DOT" + """Dot product similarity. + + Measures the algebraic similarity between two vectors. + """ + COSINE = "COSINE" + """Cosine similarity. + + Measures the cosine of the angle between two vectors. + """ + MANHATTAN = "MANHATTAN" + """Manhattan distance (L1 norm). + + Calculates the sum of absolute differences across dimensions. + """ + + +class VectorStorageFormat(Enum): + """Enum representing the data format used to store vector components. + + See :ref:`oracle_vector_datatype` for background. + + .. versionadded:: 2.0.41 + + """ + + INT8 = "INT8" + """ + 8-bit integer format. + """ + BINARY = "BINARY" + """ + Binary format. + """ + FLOAT32 = "FLOAT32" + """ + 32-bit floating-point format. + """ + FLOAT64 = "FLOAT64" + """ + 64-bit floating-point format. + """ + + +@dataclass +class VectorIndexConfig: + """Define the configuration for Oracle VECTOR Index. + + See :ref:`oracle_vector_datatype` for background. + + .. versionadded:: 2.0.41 + + :param index_type: Enum value from :class:`.VectorIndexType` + Specifies the indexing method. For HNSW, this must be + :attr:`.VectorIndexType.HNSW`. + + :param distance: Enum value from :class:`.VectorDistanceType` + specifies the metric for calculating distance between VECTORS. + + :param accuracy: interger. Should be in the range 0 to 100 + Specifies the accuracy of the nearest neighbor search during + query execution. + + :param parallel: integer. Specifies degree of parallelism. + + :param hnsw_neighbors: interger. Should be in the range 0 to + 2048. Specifies the number of nearest neighbors considered + during the search. The attribute :attr:`.VectorIndexConfig.hnsw_neighbors` + is HNSW index specific. + + :param hnsw_efconstruction: integer. Should be in the range 0 + to 65535. Controls the trade-off between indexing speed and + recall quality during index construction. The attribute + :attr:`.VectorIndexConfig.hnsw_efconstruction` is HNSW index + specific. + + :param ivf_neighbor_partitions: integer. Should be in the range + 0 to 10,000,000. Specifies the number of partitions used to + divide the dataset. The attribute + :attr:`.VectorIndexConfig.ivf_neighbor_partitions` is IVF index + specific. + + :param ivf_sample_per_partition: integer. Should be between 1 + and ``num_vectors / neighbor partitions``. Specifies the + number of samples used per partition. The attribute + :attr:`.VectorIndexConfig.ivf_sample_per_partition` is IVF index + specific. + + :param ivf_min_vectors_per_partition: integer. From 0 (no trimming) + to the total number of vectors (results in 1 partition). Specifies + the minimum number of vectors per partition. The attribute + :attr:`.VectorIndexConfig.ivf_min_vectors_per_partition` + is IVF index specific. + + """ + + index_type: VectorIndexType = VectorIndexType.HNSW + distance: Optional[VectorDistanceType] = None + accuracy: Optional[int] = None + hnsw_neighbors: Optional[int] = None + hnsw_efconstruction: Optional[int] = None + ivf_neighbor_partitions: Optional[int] = None + ivf_sample_per_partition: Optional[int] = None + ivf_min_vectors_per_partition: Optional[int] = None + parallel: Optional[int] = None + + def __post_init__(self): + self.index_type = VectorIndexType(self.index_type) + for field in [ + "hnsw_neighbors", + "hnsw_efconstruction", + "ivf_neighbor_partitions", + "ivf_sample_per_partition", + "ivf_min_vectors_per_partition", + "parallel", + "accuracy", + ]: + value = getattr(self, field) + if value is not None and not isinstance(value, int): + raise TypeError( + f"{field} must be an integer if" + f"provided, got {type(value).__name__}" + ) + + +class VECTOR(types.TypeEngine): + """Oracle VECTOR datatype. + + For complete background on using this type, see + :ref:`oracle_vector_datatype`. + + .. versionadded:: 2.0.41 + + """ + + cache_ok = True + __visit_name__ = "VECTOR" + + _typecode_map = { + VectorStorageFormat.INT8: "b", # Signed int + VectorStorageFormat.BINARY: "B", # Unsigned int + VectorStorageFormat.FLOAT32: "f", # Float + VectorStorageFormat.FLOAT64: "d", # Double + } + + def __init__(self, dim=None, storage_format=None): + """Construct a VECTOR. + + :param dim: integer. The dimension of the VECTOR datatype. This + should be an integer value. + + :param storage_format: VectorStorageFormat. The VECTOR storage + type format. This may be Enum values form + :class:`.VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64. + + """ + if dim is not None and not isinstance(dim, int): + raise TypeError("dim must be an interger") + if storage_format is not None and not isinstance( + storage_format, VectorStorageFormat + ): + raise TypeError( + "storage_format must be an enum of type VectorStorageFormat" + ) + self.dim = dim + self.storage_format = storage_format + + def _cached_bind_processor(self, dialect): + """ + Convert a list to a array.array before binding it to the database. + """ + + def process(value): + if value is None or isinstance(value, array.array): + return value + + # Convert list to a array.array + elif isinstance(value, list): + typecode = self._array_typecode(self.storage_format) + value = array.array(typecode, value) + return value + + else: + raise TypeError("VECTOR accepts list or array.array()") + + return process + + def _cached_result_processor(self, dialect, coltype): + """ + Convert a array.array to list before binding it to the database. + """ + + def process(value): + if isinstance(value, array.array): + return list(value) + + return process + + def _array_typecode(self, typecode): + """ + Map storage format to array typecode. + """ + return self._typecode_map.get(typecode, "d") + + class comparator_factory(types.TypeEngine.Comparator): + def l2_distance(self, other): + return self.op("<->", return_type=Float)(other) + + def inner_product(self, other): + return self.op("<#>", return_type=Float)(other) + + def cosine_distance(self, other): + return self.op("<=>", return_type=Float)(other) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d137ab504e..ef7605a64b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -72,6 +72,7 @@ from .base import ColumnCollection from .base import ColumnSet from .base import CompileState from .base import DedupeColumnCollection +from .base import DialectKWArgs from .base import Executable from .base import Generative from .base import HasCompileState @@ -3927,7 +3928,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): raise NotImplementedError -class GenerativeSelect(SelectBase, Generative): +class GenerativeSelect(DialectKWArgs, SelectBase, Generative): """Base class for SELECT statements where additional elements can be added. @@ -4208,8 +4209,9 @@ class GenerativeSelect(SelectBase, Generative): count: _LimitOffsetType, with_ties: bool = False, percent: bool = False, + **dialect_kw: Any, ) -> Self: - """Return a new selectable with the given FETCH FIRST criterion + r"""Return a new selectable with the given FETCH FIRST criterion applied. This is a numeric value which usually renders as ``FETCH {FIRST | NEXT} @@ -4239,6 +4241,11 @@ class GenerativeSelect(SelectBase, Generative): :param percent: When ``True``, ``count`` represents the percentage of the total number of selected rows to return. Defaults to ``False`` + :param \**dialect_kw: Additional dialect-specific keyword arguments + may be accepted by dialects. + + .. versionadded:: 2.0.41 + .. seealso:: :meth:`_sql.GenerativeSelect.limit` @@ -4246,7 +4253,7 @@ class GenerativeSelect(SelectBase, Generative): :meth:`_sql.GenerativeSelect.offset` """ - + self._validate_dialect_kwargs(dialect_kw) self._limit_clause = None if count is None: self._fetch_clause = self._fetch_clause_options = None @@ -4488,6 +4495,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, TypedReturnsRows[_TP]): ] + SupportsCloneAnnotations._clone_annotations_traverse_internals + HasCTE._has_ctes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals ) selects: List[SelectBase] @@ -5309,6 +5317,7 @@ class Select( + HasHints._has_hints_traverse_internals + SupportsCloneAnnotations._clone_annotations_traverse_internals + Executable._executable_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals ) _cache_key_traversal: _CacheKeyTraversalType = _traverse_internals + [ @@ -5330,7 +5339,9 @@ class Select( stmt.__dict__.update(kw) return stmt - def __init__(self, *entities: _ColumnsClauseArgument[Any]): + def __init__( + self, *entities: _ColumnsClauseArgument[Any], **dialect_kw: Any + ): r"""Construct a new :class:`_expression.Select`. The public constructor for :class:`_expression.Select` is the @@ -5343,7 +5354,6 @@ class Select( ) for ent in entities ] - GenerativeSelect.__init__(self) def _scalar_type(self) -> TypeEngine[Any]: diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 467edbe183..7effcf3aa5 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -310,6 +310,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): checkparams={"param_1": 20, "param_2": 10}, ) + @testing.only_on("oracle>=23.4") + def test_fetch_type(self): + t = table("sometable", column("col1"), column("col2")) + s = select(t).fetch(2, oracle_fetch_approximate=True) + self.assert_compile( + s, + "SELECT sometable.col1, sometable.col2 FROM sometable " + "FETCH APPROX FIRST __[POSTCOMPILE_param_1] ROWS ONLY", + checkparams={"param_1": 2}, + ) + def test_limit_two(self): t = table("sometable", column("col1"), column("col2")) s = select(t).limit(10).offset(20).subquery() diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index a17b53895f..3573588948 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -21,6 +21,11 @@ from sqlalchemy import text from sqlalchemy import Unicode from sqlalchemy import UniqueConstraint from sqlalchemy.dialects import oracle +from sqlalchemy.dialects.oracle import VECTOR +from sqlalchemy.dialects.oracle import VectorDistanceType +from sqlalchemy.dialects.oracle import VectorIndexConfig +from sqlalchemy.dialects.oracle import VectorIndexType +from sqlalchemy.dialects.oracle import VectorStorageFormat from sqlalchemy.dialects.oracle.base import BINARY_DOUBLE from sqlalchemy.dialects.oracle.base import BINARY_FLOAT from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION @@ -698,6 +703,25 @@ class TableReflectionTest(fixtures.TestBase): tbl = Table("test_tablespace", m2, autoload_with=connection) assert tbl.dialect_options["oracle"]["tablespace"] == "TEMP" + @testing.only_on("oracle>=23.4") + def test_reflection_w_vector_column(self, connection, metadata): + tb1 = Table( + "test_vector", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30)), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + metadata.create_all(connection) + + m2 = MetaData() + + tb1 = Table("test_vector", m2, autoload_with=connection) + assert tb1.columns.keys() == ["id", "name", "embedding"] + class ViewReflectionTest(fixtures.TestBase): __only_on__ = "oracle" @@ -1180,6 +1204,42 @@ class RoundTripIndexTest(fixtures.TestBase): eq_(len(reflectedtable.constraints), 1) eq_(len(reflectedtable.indexes), 5) + @testing.only_on("oracle>=23.4") + def test_vector_index(self, metadata, connection): + tb1 = Table( + "test_vector", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30)), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + tb1.create(connection) + + ivf_index = Index( + "ivf_vector_index", + tb1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.IVF, + distance=VectorDistanceType.DOT, + accuracy=90, + ivf_neighbor_partitions=5, + ), + ) + ivf_index.create(connection) + + expected = [ + { + "name": "ivf_vector_index", + "column_names": ["embedding"], + "dialect_options": {}, + "unique": False, + }, + ] + eq_(inspect(connection).get_indexes("test_vector"), expected) + class DBLinkReflectionTest(fixtures.TestBase): __requires__ = ("oracle_test_dblink",) diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 7b03de88c5..331103e8f2 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1,3 +1,4 @@ +import array import datetime import decimal import os @@ -15,6 +16,7 @@ from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import FLOAT from sqlalchemy import Float +from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import LargeBinary from sqlalchemy import literal @@ -37,6 +39,11 @@ from sqlalchemy import VARCHAR from sqlalchemy.dialects.oracle import base as oracle from sqlalchemy.dialects.oracle import cx_oracle from sqlalchemy.dialects.oracle import oracledb +from sqlalchemy.dialects.oracle import VECTOR +from sqlalchemy.dialects.oracle import VectorDistanceType +from sqlalchemy.dialects.oracle import VectorIndexConfig +from sqlalchemy.dialects.oracle import VectorIndexType +from sqlalchemy.dialects.oracle import VectorStorageFormat from sqlalchemy.sql import column from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import assert_raises_message @@ -952,6 +959,194 @@ class TypesTest(fixtures.TestBase): finally: exec_sql(connection, "DROP TABLE Z_TEST") + @testing.only_on("oracle>=23.4") + def test_vector_dim(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column( + "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32) + ), + ) + + t1.create(connection) + eq_(t1.c.c1.type.dim, 3) + + @testing.only_on("oracle>=23.4") + def test_vector_insert(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR(storage_format=VectorStorageFormat.INT8)), + ) + + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6, 7, 8, 5]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_insert_array(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR), + ) + + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=array.array("b", [6, 7, 8, 5])), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + + connection.execute(t1.delete().where(t1.c.id == 1)) + + connection.execute( + t1.insert(), dict(id=1, c1=array.array("b", [6, 7])) + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_multiformat_insert(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR), + ) + + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6.12, 7.54, 8.33]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6.12, 7.54, 8.33]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_format(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column( + "c1", VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32) + ), + ) + + t1.create(connection) + eq_(t1.c.c1.type.storage_format, VectorStorageFormat.FLOAT32) + + @testing.only_on("oracle>=23.4") + def test_vector_hnsw_index(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + + t1.create(connection) + + hnsw_index = Index( + "hnsw_vector_index", t1.c.embedding, oracle_vector=True + ) + hnsw_index.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_ivf_index(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32), + ), + ) + + t1.create(connection) + ivf_index = Index( + "ivf_vector_index", + t1.c.embedding, + oracle_vector=VectorIndexConfig( + index_type=VectorIndexType.IVF, + distance=VectorDistanceType.DOT, + accuracy=90, + ivf_neighbor_partitions=5, + ), + ) + ivf_index.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + + @testing.only_on("oracle>=23.4") + def test_vector_l2_distance(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column( + "embedding", + VECTOR(dim=3, storage_format=VectorStorageFormat.INT8), + ), + ) + + t1.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10])) + connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3])) + connection.execute( + t1.insert(), + dict(id=3, embedding=[15, 16, 17]), + ) + + query_vector = [2, 3, 4] + res = connection.execute( + t1.select().order_by((t1.c.embedding.l2_distance(query_vector))) + ).first() + eq_(res.embedding, [1, 2, 3]) + class LOBFetchTest(fixtures.TablesTest): __only_on__ = "oracle" diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 0457770442..77743b9c92 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -44,6 +44,7 @@ from sqlalchemy.sql import True_ from sqlalchemy.sql import type_coerce from sqlalchemy.sql import visitors from sqlalchemy.sql.annotation import Annotated +from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.base import HasCacheKey from sqlalchemy.sql.base import SingletonConstant from sqlalchemy.sql.elements import _label_reference @@ -537,6 +538,7 @@ class CoreFixtures: select(table_a.c.a).fetch(2, percent=True), select(table_a.c.a).fetch(2, with_ties=True), select(table_a.c.a).fetch(2, with_ties=True, percent=True), + select(table_a.c.a).fetch(2, oracle_fetch_approximate=True), select(table_a.c.a).fetch(2).offset(3), select(table_a.c.a).fetch(2).offset(5), select(table_a.c.a).limit(2).offset(5), @@ -1635,7 +1637,12 @@ class HasCacheKeySubclass(fixtures.TestBase): @testing.combinations( *all_hascachekey_subclasses( - ignore_subclasses=[Annotated, NoInit, SingletonConstant] + ignore_subclasses=[ + Annotated, + NoInit, + SingletonConstant, + DialectKWArgs, + ] ) ) def test_init_args_in_traversal(self, cls: type):