From: suraj Date: Thu, 6 Feb 2025 08:02:47 +0000 (+0530) Subject: Fixes: #12317 Added vector datatype support in Oracle dialect X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=c3fed5ddd79abe748d04cb511b98aa33759b5f19;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixes: #12317 Added vector datatype support in Oracle dialect Fixes: #12317 Added vector datatype support in Oracle dialect --- diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index 7ceb743d61..c05d8bdf87 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -32,6 +32,7 @@ from .base import ROWID from .base import TIMESTAMP from .base import VARCHAR from .base import VARCHAR2 +from .base import VECTOR # Alias oracledb also as oracledb_async oracledb_async = type( @@ -64,4 +65,5 @@ __all__ = ( "NVARCHAR2", "ROWID", "REAL", + "VECTOR", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 3d3ff9d517..131a341bb4 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -771,6 +771,7 @@ from .types import RAW from .types import ROWID # noqa from .types import TIMESTAMP from .types import VARCHAR2 # noqa +from .types import VECTOR from ... import Computed from ... import exc from ... import schema as sa_schema @@ -850,6 +851,7 @@ ischema_names = { "BINARY_DOUBLE": BINARY_DOUBLE, "BINARY_FLOAT": BINARY_FLOAT, "ROWID": ROWID, + "VECTOR": VECTOR, } @@ -1007,6 +1009,16 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_ROWID(self, type_, **kw): return "ROWID" + def visit_VECTOR(self, type_, **kw): + if type_.dim is None and type_.storage_format is None: + return f"VECTOR" + elif type_.storage_format is None: + return f"VECTOR({type_.dim},*)" + elif type_.dim is None: + return f"VECTOR(*, {type_.storage_format})" + else: + return f"VECTOR({type_.dim},{type_.storage_format})" + class OracleCompiler(compiler.SQLCompiler): """Oracle compiler modifies the lexical structure of Select @@ -1525,6 +1537,9 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " + vector_options = index.dialect_options["oracle"]["vector"] + if vector_options: + text += "VECTOR " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), @@ -1542,6 +1557,45 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += " COMPRESS %d" % ( index.dialect_options["oracle"]["compress"] ) + if vector_options: + if vector_options is True: + vector_options = {} + parts = [] + parameters = vector_options.get("parameters", {}) + using = parameters.get("type", "HNSW").upper() + if using == "HNSW": + parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH") + elif using == "IVF": + parts.append("ORGANIZATION NEIGHBOR PARTITIONS") + vector_distance = vector_options.get("distance") + if vector_distance is not None: + vector_distance = vector_distance.upper() + if vector_distance not in ( + "EUCLIDEAN", + "DOT", + "COSINE", + "MANHATTAN", + ): + raise ValueError("Unknown vector_distance value") + parts.append(f"DISTANCE {vector_distance}") + target_accuracy = vector_options.get("accuracy") + if target_accuracy is not None: + if target_accuracy < 0 or target_accuracy > 100: + raise ValueError( + "Accuracy value should be an integer between 0 and 100" + ) + parts.append(f"WITH TARGET ACCURACY {target_accuracy}") + if parameters: + parameters_str = ", ".join( + f"{k} {v}" for k, v in parameters.items() + ) + parts.append(f"PARAMETERS ({parameters_str})") + parallel = vector_options.get("parallel") + if parallel is not None: + if not isinstance(parallel, int): + raise ValueError("Parallel value must be an integer") + parts.append(f"PARALLEL {parallel}") + text += " " + " ".join(parts) return text def post_create_table(self, table): @@ -1693,7 +1747,14 @@ class OracleDialect(default.DefaultDialect): "tablespace": None, }, ), - (sa_schema.Index, {"bitmap": False, "compress": False}), + ( + sa_schema.Index, + { + "bitmap": False, + "compress": False, + "vector": False, + }, + ), (sa_schema.Sequence, {"order": None}), (sa_schema.Identity, {"order": None, "on_null": None}), ] diff --git a/lib/sqlalchemy/dialects/oracle/oracledb.py b/lib/sqlalchemy/dialects/oracle/oracledb.py index 8105608837..201902f997 100644 --- a/lib/sqlalchemy/dialects/oracle/oracledb.py +++ b/lib/sqlalchemy/dialects/oracle/oracledb.py @@ -591,6 +591,228 @@ SQLAlchemy type (or a subclass of such). .. versionadded:: 2.0.0 added support for the python-oracledb driver. +VECTOR Datatype +--------------- + +Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine +learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers, +8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers. +For more information on the VECTOR datatype please visit this `link. +`_ + +CREATE TABLE +^^^^^^^^^^^^ + +With the VECTOR datatype, you can specify the dimension for the data and the storage +format. To create a table that includes a VECTOR column:: + + from sqlalchemy.dialects.oracle import VECTOR + + t = Table("t1", metadata, + Column('id', Integer, primary_key=True), + Column("embedding", VECTOR(3, 'float32'), + Column(...), ... + ) + +Vectors can also be defined with an arbitrary number of dimensions and formats. This allows +you to specify vectors of different dimensions with the various storage formats mentioned above. + +For Example + +* In this case, the storage format is flexible, allowing any vector type data to be inserted, + such as int8 or binary etc. + + vector_col:Mapped[array.array] = mapped_column(VECTOR(3)) + +* The dimension is flexible in this case, meaning that any dimension vector can be used. + + vector_col:Mapped[array.array] = mapped_column(VECTOR('int8')) + +* Both the dimensions and the storage format are flexible. + + vector_col:Mapped[array.array] = mapped_column(VECTOR) + +INSERT VECTOR DATA +^^^^^^^^^^^^^^^^^^ + +VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type +float (32-bit), double (64-bit), or int8_t (8-bit signed integer) are used as bind values when +inserting VECTOR columns:: + + from sqlalchemy import insert, select + import array + + vector_data_8 = [1, 2, 3] + statement = insert(t1) + with engine.connect() as conn: + conn.execute(statement,[ + {"id":1,"embedding":vector_data_8}, + ]) + +VECTOR INDEXES +^^^^^^^^^^^^^^ + +There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW +index. + +To utilize VECTOR indexing, set the `oracle_vector` parameter to True to use +the default values provided by Oracle. HNSW is the default indexing method:: + + Index( + 'vector_index', + t1.c.embedding, + oracle_vector = True, + ) + +If you wish to use custom parameters, you can specify all the parameters as a dictionary +in the `oracle_vector` option. To learn more about the parameters that can be passed please +visit this `link. `_ + +Configuring Oracle VECTOR Indexes +================================= + +When using Oracle VECTOR indexes, the configuration parameters are divided into two levels: +**top-level keys** and **nested keys**. This structure applies to both HNSW and IVF VECTOR indexes. + +Top-Level Keys +============== + +These keys are specified directly under the ``oracle_vector`` dictionary. + +* ``accuracy``: + - Specifies the accuracy of the nearest neighbor search during query execution. + - **Valid Range**: Greater than 0 and less than or equal to 100. + - **Example**: ``'accuracy': 95`` + +* ``distance``: + - Specifies the metric for calculating distance between VECTORS. + - **Valid Values**: ``"EUCLIDEAN"``, ``"COSINE"``, ``"DOT"``, ``"MANHATTAN"``. + - **Example**: ``'distance': "COSINE"`` + +* ``parameters``: + - A nested dictionary where method-specific options are defined (e.g., HNSW or IVF-specific settings). + - **Example**: ``'parameters': {...}`` + +Nested keys in parameters +========================= + +These keys are specific to the indexing method and are included under the ``parameters`` dictionary. + +HNSW Parameters +=============== + +* ``type``: + - Specifies the indexing method. For HNSW, this must be ``"HNSW"``. + - **Placement**: Nested under ``parameters``. + - **Example**: ``'type': 'HNSW'`` + +* ``neighbors``: + - The number of nearest neighbors considered during the search. + - **Valid Range**: Greater than 0 and less than or equal to 2048. + - **Placement**: Nested under ``parameters``. + - **Example**: ``'neighbors': 20`` + +* ``efconstruction``: + - Controls the trade-off between indexing speed and recall quality during index construction. + - **Valid Range**: Greater than 0 and less than or equal to 65535. + - **Placement**: Nested under ``parameters``. + - **Example**: ``'efconstruction': 300`` + +IVF Parameters +============== + +* ``type``: + - Specifies the indexing method. For IVF, this must be ``"IVF"``. + - **Placement**: Nested under ``parameters``. + - **Example**: ``'type': 'IVF'`` + +* ``neighbor partitions``: + - The number of partitions used to divide the dataset. + - **Valid Range**: Greater than 0 and less than or equal to 10,000,000. + - **Placement**: Nested under ``parameters``. + - **Example**: ``'neighbor partitions': 10`` + +* ``sample_per_partition``: + - The number of samples used per partition. + - **Valid Range**: Between 1 and ``num_vectors / neighbor_partitions``. + - **Placement**: Nested under ``parameters``. + - **Example**: ``'sample_per_partition': 5`` + +* ``min_vectors_per_partition``: + - The minimum number of vectors per partition. + - **Valid Range**: From 0 (no trimming) to the total number of vectors (results in 1 partition). + - **Placement**: Nested under ``parameters``. + - **Example**: ``'min_vectors_per_partition': 100`` + +Example Configurations +====================== + +For custom configurations, the parameters can be specified as shown in the following examples:: + + Index( + 'hnsw_vector_index', + t1.c.embedding, + oracle_vector = { + 'accuracy':95, + 'distance':"COSINE", + 'parameters':{ + 'type':'HNSW', + 'neighbors':20', + efconstruction':300 + } + } + ) + + Index( + 'ivf_vector_index', + t1.c.embedding, + oracle_vector = { + 'accuracy':90, + 'distance':"DOT", + 'parameters':{ + 'type':'IVF',' + neighbor partitions':10 + } + } + ) + +Similarity Searching +^^^^^^^^^^^^^^^^^^^^ + +You can use the following shorthand VECTOR distance functions: + +* ``l2_distance`` +* ``cosine_distance`` +* ``inner_product`` + +Example Usage:: + + from sqlalchemy.orm import Session + from sqlalchemy.sql import func + import array + + session = Session(bind=engine) + query_vector = [2,3,4] + result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3)) + + for user in vector: + print(user.id,user.embedding) + +Exact and Approximate Searching +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similarity searches tend to get data from one or more clusters depending on the value of the query VECTOR and the fetch +size. Approximate searches using VECTOR indexes can limit the searches to specific clusters, whereas exact searches visit +VECTORS across all clusters. +You can use the fetch_type clause to set the searching to be either EXACT, APPROX or APPROXIMATE:: + + result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("EXACT") + result_vector = session.scalars(select(t1).order_by(t1.embedding.l1_distance(query_vector)).limit(3)).fetch_type("APPROX") + + +.. versionadded:: 2.1.0 Added support for VECTOR specific to Oracle Database. +:class:`_oracle.VECTOR` datatype. + """ # noqa from __future__ import annotations diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index 06aeaace2f..f128a42352 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -11,11 +11,14 @@ import datetime as dt from typing import Optional from typing import Type from typing import TYPE_CHECKING +import sqlalchemy.types as types +from sqlalchemy.types import UserDefinedType, Float from ... import exc from ...sql import sqltypes from ...types import NVARCHAR from ...types import VARCHAR +import array if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -314,3 +317,82 @@ class ROWID(sqltypes.TypeEngine): class _OracleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMBER + + +class VECTOR(types.TypeEngine): + """Oracle VECTOR datatype.""" + + cache_ok = True + __visit_name__ = "VECTOR" + + def __init__(self, dim=None, storage_format=None, *args): + """ + :param dim: The dimension of the VECTOR datatype. This should be an + integer value. + :param storage_format: The VECTOR storage type format. This + may be int8, binary, float32, or float64. + """ + if dim is not None and isinstance(dim, int): + self.dim = dim + self.storage_format = storage_format + + elif dim is not None and isinstance(dim, str): + self.dim = storage_format + self.storage_format = dim + + else: + self.dim = storage_format + self.storage_format = dim + + def _cached_bind_processor(self, dialect): + """ + Convert a list to a array.array before binding it to the database. + """ + + def process(value): + if value is None or isinstance(value, array.array): + return value + + # Convert list to a array.array + elif isinstance(value, list): + format = self._array_typecode(self.storage_format) + value = array.array(format, value) + return value + + else: + raise TypeError("VECTOR accepts list or array.array()") + + return process + + def _cached_result_processor(self, dialect, coltype): + """ + Convert a array.array to list before binding it to the database. + """ + + def process(value): + if isinstance(value, array.array): + return list(value) + + return process + + def _array_typecode(self, format): + """ + Map storage format to array typecode. + """ + typecode_map = { + "int8": "b", # Signed int + "binary": "B", # Unsigned int + "float32": "f", # Float + "float64": "d", # Double + } + return typecode_map.get(format, "d") + + class comparator_factory(types.TypeEngine.Comparator): + def l2_distance(self, other): + return self.op("<->", return_type=Float)(other) + + def inner_product(self, other): + return self.op("<#>", return_type=Float)(other) + + def cosine_distance(self, other): + return self.op("<=>", return_type=Float)(other) diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index b5ce61222e..cf22f20f97 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -2,6 +2,7 @@ import datetime import decimal import os import random +import array from sqlalchemy import bindparam from sqlalchemy import cast @@ -34,9 +35,12 @@ from sqlalchemy import types as sqltypes from sqlalchemy import Unicode from sqlalchemy import UnicodeText from sqlalchemy import VARCHAR +from sqlalchemy import Index +from sqlalchemy import delete from sqlalchemy.dialects.oracle import base as oracle from sqlalchemy.dialects.oracle import cx_oracle from sqlalchemy.dialects.oracle import oracledb +from sqlalchemy.dialects.oracle import VECTOR from sqlalchemy.sql import column from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import AssertsCompiledSQL @@ -951,6 +955,198 @@ class TypesTest(fixtures.TestBase): finally: exec_sql(connection, "DROP TABLE Z_TEST") + def test_vector_dim(self, metadata, connection): + t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32"))) + + if testing.against("oracle>23.4"): + t1.create(connection) + eq_(t1.c.c1.type.dim, 3) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_insert(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR("int8")), + ) + + if testing.against("oracle>23.4"): + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6, 7, 8, 5]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_insert_array(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR), + ) + + if testing.against("oracle>23.4"): + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=array.array("b", [6, 7, 8, 5])), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7, 8, 5]), + ) + + connection.execute(t1.delete().where(t1.c.id == 1)) + + connection.execute( + t1.insert(), dict(id=1, c1=array.array("b", [6, 7])) + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_multiformat_insert(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer, primary_key=True), + Column("c1", VECTOR), + ) + + if testing.against("oracle>23.4"): + t1.create(connection) + connection.execute( + t1.insert(), + dict(id=1, c1=[6.12, 7.54, 8.33]), + ) + eq_( + connection.execute(t1.select()).first(), + (1, [6.12, 7.54, 8.33]), + ) + connection.execute(t1.delete().where(t1.c.id == 1)) + connection.execute(t1.insert(), dict(id=1, c1=[6, 7])) + eq_( + connection.execute(t1.select()).first(), + (1, [6, 7]), + ) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_format(self, metadata, connection): + t1 = Table("t1", metadata, Column("c1", VECTOR(3, "float32"))) + + if testing.against("oracle>23.4"): + t1.create(connection) + eq_(t1.c.c1.type.storage_format, "float32") + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_hnsw_index(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column("embedding", VECTOR(3, "float32")), + ) + + if testing.against("oracle>23.4"): + t1.create(connection) + + hnsw_index = Index( + "hnsw_vector_index", t1.c.embedding, oracle_vector=True + ) + hnsw_index.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_ivf_index(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column("embedding", VECTOR(3, "float32")), + ) + + if testing.against("oracle>23.4"): + t1.create(connection) + ivf_index = Index( + "ivf_vector_index", + t1.c.embedding, + oracle_vector={ + "accuracy": 90, + "distance": "DOT", + "parameters": {"type": "IVF", "neighbor partitions": 10}, + }, + ) + ivf_index.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[6, 7, 8])) + eq_( + connection.execute(t1.select()).first(), + (1, [6.0, 7.0, 8.0]), + ) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + + def test_vector_l2_distance(self, metadata, connection): + t1 = Table( + "t1", + metadata, + Column("id", Integer), + Column("embedding", VECTOR(3, "int8")), + ) + + if testing.against("oracle>23.4"): + t1.create(connection) + + connection.execute(t1.insert(), dict(id=1, embedding=[8, 9, 10])) + connection.execute(t1.insert(), dict(id=2, embedding=[1, 2, 3])) + connection.execute( + t1.insert(), + dict(id=3, embedding=[15, 16, 17]), + ) + + query_vector = [2, 3, 4] + res = connection.execute( + t1.select().order_by( + (t1.c.embedding.l2_distance(query_vector)) + ) + ).first() + eq_(res.embedding, [1, 2, 3]) + else: + with expect_raises_message(exc.DatabaseError, "ORA-03060"): + t1.create(connection) + class LOBFetchTest(fixtures.TablesTest): __only_on__ = "oracle"