git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #12317 Added vector datatype support in Oracle dialect
author suraj <suraj.shaw@oracle.com>
Thu, 6 Feb 2025 08:02:47 +0000 (13:32 +0530)
committer suraj <suraj.shaw@oracle.com>
Thu, 6 Feb 2025 15:38:28 +0000 (21:08 +0530)
Fixes: #12317 Added vector datatype support in Oracle dialect
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/oracle/types.py
test/dialect/oracle/test_types.py

index 7ceb743d616ec31b95e0012fe1ffb53dda801e1a..c05d8bdf872961b15bcf8ad967ee9807cd23798e 100644 (file)
@@ -32,6 +32,7 @@ from .base import ROWID
 from .base import TIMESTAMP
 from .base import VARCHAR
 from .base import VARCHAR2
+from .base import VECTOR
 
 # Alias oracledb also as oracledb_async
 oracledb_async = type(
@@ -64,4 +65,5 @@ __all__ = (
     "NVARCHAR2",
     "ROWID",
     "REAL",
+    "VECTOR",
 )
index 3d3ff9d5170b3d7d5ae319ac47c71a5b0d83d5e8..131a341bb432c453cf2f6113ba24284db74f543e 100644 (file)
@@ -771,6 +771,7 @@ from .types import RAW
 from .types import ROWID  # noqa
 from .types import TIMESTAMP
 from .types import VARCHAR2  # noqa
+from .types import VECTOR
 from ... import Computed
 from ... import exc
 from ... import schema as sa_schema
@@ -850,6 +851,7 @@ ischema_names = {
     "BINARY_DOUBLE": BINARY_DOUBLE,
     "BINARY_FLOAT": BINARY_FLOAT,
     "ROWID": ROWID,
+    "VECTOR": VECTOR,
 }
 
 
@@ -1007,6 +1009,16 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
     def visit_ROWID(self, type_, **kw):
         return "ROWID"
 
+    def visit_VECTOR(self, type_, **kw):
+        """Render the DDL type string for a ``VECTOR`` column.
+
+        An unset dimension or storage format is rendered as ``*``; when
+        both are unset only the bare ``VECTOR`` keyword is emitted.
+        """
+        dim = type_.dim
+        storage_format = type_.storage_format
+
+        if dim is None:
+            if storage_format is None:
+                return "VECTOR"
+            return f"VECTOR(*, {storage_format})"
+
+        if storage_format is None:
+            return f"VECTOR({dim},*)"
+
+        return f"VECTOR({dim},{storage_format})"
+
 
 class OracleCompiler(compiler.SQLCompiler):
     """Oracle compiler modifies the lexical structure of Select
@@ -1525,6 +1537,9 @@ class OracleDDLCompiler(compiler.DDLCompiler):
             text += "UNIQUE "
         if index.dialect_options["oracle"]["bitmap"]:
             text += "BITMAP "
+        vector_options = index.dialect_options["oracle"]["vector"]
+        if vector_options:
+            text += "VECTOR "
         text += "INDEX %s ON %s (%s)" % (
             self._prepared_index_name(index, include_schema=True),
             preparer.format_table(index.table, use_schema=True),
@@ -1542,6 +1557,45 @@ class OracleDDLCompiler(compiler.DDLCompiler):
                 text += " COMPRESS %d" % (
                     index.dialect_options["oracle"]["compress"]
                 )
+        if vector_options:
+            if vector_options is True:
+                vector_options = {}
+            parts = []
+            parameters = vector_options.get("parameters", {})
+            using = parameters.get("type", "HNSW").upper()
+            if using == "HNSW":
+                parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH")
+            elif using == "IVF":
+                parts.append("ORGANIZATION NEIGHBOR PARTITIONS")
+            vector_distance = vector_options.get("distance")
+            if vector_distance is not None:
+                vector_distance = vector_distance.upper()
+                if vector_distance not in (
+                    "EUCLIDEAN",
+                    "DOT",
+                    "COSINE",
+                    "MANHATTAN",
+                ):
+                    raise ValueError("Unknown vector_distance value")
+                parts.append(f"DISTANCE {vector_distance}")
+            target_accuracy = vector_options.get("accuracy")
+            if target_accuracy is not None:
+                if target_accuracy < 0 or target_accuracy > 100:
+                    raise ValueError(
+                        "Accuracy value should be an integer between 0 and 100"
+                    )
+                parts.append(f"WITH TARGET ACCURACY {target_accuracy}")
+            if parameters:
+                parameters_str = ", ".join(
+                    f"{k} {v}" for k, v in parameters.items()
+                )
+                parts.append(f"PARAMETERS ({parameters_str})")
+            parallel = vector_options.get("parallel")
+            if parallel is not None:
+                if not isinstance(parallel, int):
+                    raise ValueError("Parallel value must be an integer")
+                parts.append(f"PARALLEL {parallel}")
+            text += " " + " ".join(parts)
         return text
 
     def post_create_table(self, table):
@@ -1693,7 +1747,14 @@ class OracleDialect(default.DefaultDialect):
                 "tablespace": None,
             },
         ),
-        (sa_schema.Index, {"bitmap": False, "compress": False}),
+        (
+            sa_schema.Index,
+            {
+                "bitmap": False,
+                "compress": False,
+                "vector": False,
+            },
+        ),
         (sa_schema.Sequence, {"order": None}),
         (sa_schema.Identity, {"order": None, "on_null": None}),
     ]
index 8105608837f752274ae22de030f9032856af9e61..201902f997f04a9db96c474289fccd673c9d30b7 100644 (file)
@@ -591,6 +591,228 @@ SQLAlchemy type (or a subclass of such).
 
 .. versionadded:: 2.0.0 added support for the python-oracledb driver.
 
+VECTOR Datatype
+---------------
+
+Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence and machine
+learning search operations. The VECTOR datatype is a homogeneous array of 8-bit signed integers,
+8-bit unsigned integers, 32-bit floating-point numbers, or 64-bit floating-point numbers.
+For more information on the VECTOR datatype please visit this `link.
+<https://python-oracledb.readthedocs.io/en/latest/user_guide/vector_data_type.html>`_
+
+CREATE TABLE
+^^^^^^^^^^^^
+
+With the VECTOR datatype, you can specify the dimension for the data and the storage
+format. To create a table that includes a VECTOR column::
+
+    from sqlalchemy.dialects.oracle import VECTOR
+
+    t1 = Table(
+        "t1",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("embedding", VECTOR(3, "float32")),
+        Column(...), ...
+    )
+
+Vectors can also be defined with an arbitrary number of dimensions and formats. This allows
+you to specify vectors of different dimensions with the various storage formats mentioned above.
+
+For example:
+
+* In this case, the storage format is flexible, allowing vector data of any
+  format, such as int8 or binary, to be inserted::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR(3))
+
+* The dimension is flexible in this case, meaning that vectors of any
+  dimension can be used::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR("int8"))
+
+* Both the dimension and the storage format are flexible::
+
+    vector_col: Mapped[array.array] = mapped_column(VECTOR)
+
+INSERT VECTOR DATA
+^^^^^^^^^^^^^^^^^^
+
+VECTOR data can be inserted using Python list or Python array.array() objects. Python arrays of type
+float (32-bit), double (64-bit), or int8_t (8-bit signed integer) are used as bind values when
+inserting VECTOR columns::
+
+    from sqlalchemy import insert, select
+    import array
+
+    vector_data_8 = [1, 2, 3]
+    statement = insert(t1)
+    with engine.begin() as conn:
+        conn.execute(
+            statement,
+            [{"id": 1, "embedding": vector_data_8}],
+        )
+
+VECTOR INDEXES
+^^^^^^^^^^^^^^
+
+There are two VECTOR indexes supported in VECTOR search: IVF Flat index and HNSW
+index.
+
+To utilize VECTOR indexing, set the ``oracle_vector`` parameter to ``True`` to use
+the default values provided by Oracle. HNSW is the default indexing method::
+
+    Index(
+            'vector_index',
+            t1.c.embedding,
+            oracle_vector = True,
+        )
+
+If you wish to use custom parameters, you can specify all the parameters as a dictionary
+in the ``oracle_vector`` option. To learn more about the parameters that can be passed, please
+visit this `link. <https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/create-vector-index.html>`_
+
+Configuring Oracle VECTOR Indexes
+=================================
+
+When using Oracle VECTOR indexes, the configuration parameters are divided into two levels:
+**top-level keys** and **nested keys**. This structure applies to both HNSW and IVF VECTOR indexes.
+
+Top-Level Keys
+==============
+
+These keys are specified directly under the ``oracle_vector`` dictionary.
+
+* ``accuracy``:
+    - Specifies the accuracy of the nearest neighbor search during query execution.
+    - **Valid Range**: Greater than 0 and less than or equal to 100.
+    - **Example**: ``'accuracy': 95``
+
+* ``distance``:
+    - Specifies the metric for calculating distance between VECTORS.
+    - **Valid Values**: ``"EUCLIDEAN"``, ``"COSINE"``, ``"DOT"``, ``"MANHATTAN"``.
+    - **Example**: ``'distance': "COSINE"``
+
+* ``parameters``:
+    - A nested dictionary where method-specific options are defined (e.g., HNSW or IVF-specific settings).
+    - **Example**: ``'parameters': {...}``
+
+Nested keys in parameters
+=========================
+
+These keys are specific to the indexing method and are included under the ``parameters`` dictionary.
+
+HNSW Parameters
+===============
+
+* ``type``:
+    - Specifies the indexing method. For HNSW, this must be ``"HNSW"``.
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'type': 'HNSW'``
+
+* ``neighbors``:
+    - The number of nearest neighbors considered during the search.
+    - **Valid Range**: Greater than 0 and less than or equal to 2048.
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'neighbors': 20``
+
+* ``efconstruction``:
+    - Controls the trade-off between indexing speed and recall quality during index construction.
+    - **Valid Range**: Greater than 0 and less than or equal to 65535.
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'efconstruction': 300``
+
+IVF Parameters
+==============
+
+* ``type``:
+    - Specifies the indexing method. For IVF, this must be ``"IVF"``.
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'type': 'IVF'``
+   
+* ``neighbor partitions``:
+    - The number of partitions used to divide the dataset.
+    - **Valid Range**: Greater than 0 and less than or equal to 10,000,000.
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'neighbor partitions': 10``
+
+* ``sample_per_partition``:
+    - The number of samples used per partition.
+    - **Valid Range**: Between 1 and ``num_vectors / neighbor_partitions``.
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'sample_per_partition': 5``
+
+* ``min_vectors_per_partition``:
+    - The minimum number of vectors per partition.
+    - **Valid Range**: From 0 (no trimming) to the total number of vectors (results in 1 partition).
+    - **Placement**: Nested under ``parameters``.
+    - **Example**: ``'min_vectors_per_partition': 100``
+
+Example Configurations
+======================
+
+For custom configurations, the parameters can be specified as shown in the following examples::
+
+    Index(
+        "hnsw_vector_index",
+        t1.c.embedding,
+        oracle_vector={
+            "accuracy": 95,
+            "distance": "COSINE",
+            "parameters": {
+                "type": "HNSW",
+                "neighbors": 20,
+                "efconstruction": 300,
+            },
+        },
+    )
+
+    Index(
+        "ivf_vector_index",
+        t1.c.embedding,
+        oracle_vector={
+            "accuracy": 90,
+            "distance": "DOT",
+            "parameters": {
+                "type": "IVF",
+                "neighbor partitions": 10,
+            },
+        },
+    )
+
+Similarity Searching
+^^^^^^^^^^^^^^^^^^^^
+
+You can use the following shorthand VECTOR distance functions:
+
+* ``l2_distance``
+* ``cosine_distance``
+* ``inner_product``
+
+Example Usage::
+
+    from sqlalchemy import select
+    from sqlalchemy.orm import Session
+
+    query_vector = [2, 3, 4]
+    with Session(engine) as session:
+        result_vector = session.execute(
+            select(t1)
+            .order_by(t1.c.embedding.l2_distance(query_vector))
+            .limit(3)
+        )
+        for row in result_vector:
+            print(row.id, row.embedding)
+
+Exact and Approximate Searching
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Similarity searches tend to get data from one or more clusters depending on the value of the query VECTOR and the fetch
+size. Approximate searches using VECTOR indexes can limit the searches to specific clusters, whereas exact searches visit
+VECTORS across all clusters.
+You can use the fetch_type clause to set the searching to be either EXACT, APPROX or APPROXIMATE::
+
+    result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3)).fetch_type("EXACT")
+    result_vector = session.scalars(select(t1).order_by(t1.embedding.l2_distance(query_vector)).limit(3)).fetch_type("APPROX")
+
+
+.. versionadded:: 2.1.0 Added support for the Oracle Database-specific
+   :class:`_oracle.VECTOR` datatype.
+
 """  # noqa
 from __future__ import annotations
 
index 06aeaace2f5fc7dba4f14cbab4d13822064f7701..f128a42352c555b8c418ca0ffeef844f5c9e673e 100644 (file)
@@ -11,11 +11,14 @@ import datetime as dt
 from typing import Optional
 from typing import Type
 from typing import TYPE_CHECKING
+import sqlalchemy.types as types
+from sqlalchemy.types import UserDefinedType, Float
 
 from ... import exc
 from ...sql import sqltypes
 from ...types import NVARCHAR
 from ...types import VARCHAR
+import array
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
@@ -314,3 +317,82 @@ class ROWID(sqltypes.TypeEngine):
 class _OracleBoolean(sqltypes.Boolean):
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
+
+
+class VECTOR(types.TypeEngine):
+    """Oracle VECTOR datatype.
+
+    Python lists are converted to ``array.array`` objects when bound, and
+    ``array.array`` results are converted back to plain lists.
+    """
+
+    cache_ok = True
+    __visit_name__ = "VECTOR"
+
+    def __init__(self, dim=None, storage_format=None, *args):
+        """Construct a VECTOR type.
+
+        :param dim: The dimension of the VECTOR datatype. This should be an
+         integer value, or ``None`` for a flexible dimension.
+        :param storage_format: The VECTOR storage type format. This
+         may be int8, binary, float32, or float64, or ``None`` for a
+         flexible storage format.
+
+        For convenience the two arguments may be given in swapped order,
+        e.g. ``VECTOR("int8")`` or ``VECTOR("int8", 3)``. Extra positional
+        arguments are accepted and ignored.
+        """
+        # Swap only when the arguments are demonstrably reversed: a string
+        # in the ``dim`` slot, or an integer in the ``storage_format`` slot
+        # with no ``dim``.  A correctly placed ``storage_format`` with
+        # ``dim=None`` (e.g. ``VECTOR(storage_format="int8")``) must NOT be
+        # swapped, otherwise the format would end up stored as the dimension.
+        if isinstance(dim, str) or (
+            dim is None and isinstance(storage_format, int)
+        ):
+            dim, storage_format = storage_format, dim
+
+        self.dim = dim
+        self.storage_format = storage_format
+
+    def _cached_bind_processor(self, dialect):
+        """Return a callable converting bound values to ``array.array``.
+
+        ``None`` and existing ``array.array`` objects are passed through
+        unchanged; lists are converted using the typecode that matches
+        ``storage_format``.
+
+        :raises TypeError: (from the returned callable) if the value is
+         neither ``None``, a list, nor an ``array.array``.
+        """
+        # NOTE(review): SQLAlchemy's documented hooks are bind_processor() /
+        # result_processor(); overriding the _cached_* variants presumably
+        # bypasses the per-dialect memoization - confirm before release.
+
+        # The typecode depends only on the type's configuration, so resolve
+        # it once instead of once per bound value.
+        typecode = self._array_typecode(self.storage_format)
+
+        def process(value):
+            if value is None or isinstance(value, array.array):
+                return value
+            if isinstance(value, list):
+                return array.array(typecode, value)
+            raise TypeError("VECTOR accepts list or array.array()")
+
+        return process
+
+    def _cached_result_processor(self, dialect, coltype):
+        """Return a callable converting fetched ``array.array`` to a list.
+
+        Any other value, including ``None`` for SQL NULL, is returned
+        unchanged rather than being silently replaced by ``None``.
+        """
+
+        def process(value):
+            if isinstance(value, array.array):
+                return list(value)
+            return value
+
+        return process
+
+    def _array_typecode(self, format):  # noqa: A002
+        """Map a storage format name to an ``array.array`` typecode.
+
+        The lookup is case-insensitive.  Unknown or unset formats fall
+        back to ``"d"`` (double).
+        """
+        typecode_map = {
+            "int8": "b",  # Signed int
+            "binary": "B",  # Unsigned int
+            "float32": "f",  # Float
+            "float64": "d",  # Double
+        }
+        key = format.lower() if isinstance(format, str) else format
+        return typecode_map.get(key, "d")
+
+    class comparator_factory(types.TypeEngine.Comparator):
+        """Comparator adding the shorthand vector distance operators."""
+
+        def l2_distance(self, other):
+            """Render ``self <-> other`` with a ``Float`` result type."""
+            return self.op("<->", return_type=Float)(other)
+
+        def inner_product(self, other):
+            """Render ``self <#> other`` with a ``Float`` result type."""
+            return self.op("<#>", return_type=Float)(other)
+
+        def cosine_distance(self, other):
+            """Render ``self <=> other`` with a ``Float`` result type."""
+            return self.op("<=>", return_type=Float)(other)
index b5ce61222e8f5ab6c5d300032145447b09a1ccd4..cf22f20f976c5d063f758d9c323ee2b60dc74ba4 100644 (file)
@@ -2,6 +2,7 @@ import datetime
 import decimal
 import os
 import random
+import array
 
 from sqlalchemy import bindparam
 from sqlalchemy import cast
@@ -34,9 +35,12 @@ from sqlalchemy import types as sqltypes
 from sqlalchemy import Unicode
 from sqlalchemy import UnicodeText
 from sqlalchemy import VARCHAR
+from sqlalchemy import Index
+from sqlalchemy import delete
 from sqlalchemy.dialects.oracle import base as oracle
 from sqlalchemy.dialects.oracle import cx_oracle
 from sqlalchemy.dialects.oracle import oracledb
+from sqlalchemy.dialects.oracle import VECTOR
 from sqlalchemy.sql import column
 from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -951,6 +955,198 @@ class TypesTest(fixtures.TestBase):
         finally:
             exec_sql(connection, "DROP TABLE Z_TEST")
 
+    def test_vector_dim(self, metadata, connection):
+        """The dimension given to VECTOR is exposed as ``type.dim``."""
+        vec_table = Table(
+            "t1", metadata, Column("c1", VECTOR(3, "float32"))
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+        eq_(vec_table.c.c1.type.dim, 3)
+
+    def test_vector_insert(self, metadata, connection):
+        """Round-trip Python lists through a ``VECTOR("int8")`` column."""
+        vec_table = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("c1", VECTOR("int8")),
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+
+        # first a four-element vector ...
+        connection.execute(
+            vec_table.insert(), {"id": 1, "c1": [6, 7, 8, 5]}
+        )
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6, 7, 8, 5]))
+
+        connection.execute(vec_table.delete().where(vec_table.c.id == 1))
+
+        # ... then a two-element vector in the same column
+        connection.execute(vec_table.insert(), {"id": 1, "c1": [6, 7]})
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6, 7]))
+
+    def test_vector_insert_array(self, metadata, connection):
+        """Round-trip ``array.array`` values through a bare VECTOR column."""
+        vec_table = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("c1", VECTOR),
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+
+        # first a four-element signed-byte array ...
+        four_items = array.array("b", [6, 7, 8, 5])
+        connection.execute(vec_table.insert(), {"id": 1, "c1": four_items})
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6, 7, 8, 5]))
+
+        connection.execute(vec_table.delete().where(vec_table.c.id == 1))
+
+        # ... then a two-element one in the same column
+        two_items = array.array("b", [6, 7])
+        connection.execute(vec_table.insert(), {"id": 1, "c1": two_items})
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6, 7]))
+
+    def test_vector_multiformat_insert(self, metadata, connection):
+        """Insert float and integer lists into the same bare VECTOR column."""
+        vec_table = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("c1", VECTOR),
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+
+        # a list of floats ...
+        connection.execute(
+            vec_table.insert(), {"id": 1, "c1": [6.12, 7.54, 8.33]}
+        )
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6.12, 7.54, 8.33]))
+
+        connection.execute(vec_table.delete().where(vec_table.c.id == 1))
+
+        # ... followed by a shorter list of integers
+        connection.execute(vec_table.insert(), {"id": 1, "c1": [6, 7]})
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6, 7]))
+
+    def test_vector_format(self, metadata, connection):
+        """The format given to VECTOR is exposed as ``type.storage_format``."""
+        vec_table = Table(
+            "t1", metadata, Column("c1", VECTOR(3, "float32"))
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+        eq_(vec_table.c.c1.type.storage_format, "float32")
+
+    def test_vector_hnsw_index(self, metadata, connection):
+        """Create an ``oracle_vector=True`` index and round-trip one row."""
+        vec_table = Table(
+            "t1",
+            metadata,
+            Column("id", Integer),
+            Column("embedding", VECTOR(3, "float32")),
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+
+        Index(
+            "hnsw_vector_index",
+            vec_table.c.embedding,
+            oracle_vector=True,
+        ).create(connection)
+
+        connection.execute(
+            vec_table.insert(), {"id": 1, "embedding": [6, 7, 8]}
+        )
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6.0, 7.0, 8.0]))
+
+    def test_vector_ivf_index(self, metadata, connection):
+        """Create an IVF vector index from a dict and round-trip one row."""
+        vec_table = Table(
+            "t1",
+            metadata,
+            Column("id", Integer),
+            Column("embedding", VECTOR(3, "float32")),
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+
+        index_options = {
+            "accuracy": 90,
+            "distance": "DOT",
+            "parameters": {"type": "IVF", "neighbor partitions": 10},
+        }
+        Index(
+            "ivf_vector_index",
+            vec_table.c.embedding,
+            oracle_vector=index_options,
+        ).create(connection)
+
+        connection.execute(
+            vec_table.insert(), {"id": 1, "embedding": [6, 7, 8]}
+        )
+        fetched = connection.execute(vec_table.select()).first()
+        eq_(fetched, (1, [6.0, 7.0, 8.0]))
+
+    def test_vector_l2_distance(self, metadata, connection):
+        """Ordering by ``l2_distance`` returns the expected row first."""
+        vec_table = Table(
+            "t1",
+            metadata,
+            Column("id", Integer),
+            Column("embedding", VECTOR(3, "int8")),
+        )
+
+        # Backends not matching "oracle>23.4" are expected to reject the
+        # DDL with a DatabaseError mentioning ORA-03060.
+        if not testing.against("oracle>23.4"):
+            with expect_raises_message(exc.DatabaseError, "ORA-03060"):
+                vec_table.create(connection)
+            return
+
+        vec_table.create(connection)
+
+        # one INSERT per row, in this order
+        seed_rows = (
+            (1, [8, 9, 10]),
+            (2, [1, 2, 3]),
+            (3, [15, 16, 17]),
+        )
+        for row_id, vector in seed_rows:
+            connection.execute(
+                vec_table.insert(), {"id": row_id, "embedding": vector}
+            )
+
+        query_vector = [2, 3, 4]
+        distance = vec_table.c.embedding.l2_distance(query_vector)
+        res = connection.execute(
+            vec_table.select().order_by(distance)
+        ).first()
+        eq_(res.embedding, [1, 2, 3])
+
 
 class LOBFetchTest(fixtures.TablesTest):
     __only_on__ = "oracle"