]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed compilation of bitwise operators on oracle and sqlite.
authorFederico Caselli <cfederico87@gmail.com>
Mon, 29 Jul 2024 21:52:04 +0000 (23:52 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Aug 2024 13:14:30 +0000 (09:14 -0400)
Implemented bitwise operators for Oracle which was previously
non-functional due to a non-standard syntax used by this database.
Oracle's support for bitwise "or" and "xor" starts with server version 21.
Additionally repaired the implementation of "xor" for SQLite.

As part of this change, the dialect compliance test suite has been enhanced
to include support for server-side bitwise tests; third party dialect
authors should refer to new "supports_bitwise" methods in the
requirements.py file to enable these tests.

Fixes: #11663
Change-Id: I41040bd67992b6c89ed3592edca8965d5d59be9e

doc/build/changelog/unreleased_20/11663.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_select.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/11663.rst b/doc/build/changelog/unreleased_20/11663.rst
new file mode 100644 (file)
index 0000000..599cd74
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, oracle, sqlite
+    :tickets: 11663
+
+    Implemented bitwise operators for Oracle which was previously
+    non-functional due to a non-standard syntax used by this database.
+    Oracle's support for bitwise "or" and "xor" starts with server version 21.
+    Additionally repaired the implementation of "xor" for SQLite.
+
+    As part of this change, the dialect compliance test suite has been enhanced
+    to include support for server-side bitwise tests; third party dialect
+    authors should refer to new "supports_bitwise" methods in the
+    requirements.py file to enable these tests.
+
+
+
index 8e5989990ef1badbfd97ad9954c04a8fb4f12d10..058becf831e7f1cc4148d1f6419843f153c3a637 100644 (file)
@@ -1256,6 +1256,31 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_aggregate_strings_func(self, fn, **kw):
         return "LISTAGG%s" % self.function_argspec(fn, **kw)
 
+    def _visit_bitwise(self, binary, fn_name, custom_right=None, **kw):
+        left = self.process(binary.left, **kw)
+        right = self.process(
+            custom_right if custom_right is not None else binary.right, **kw
+        )
+        return f"{fn_name}({left}, {right})"
+
+    def visit_bitwise_xor_op_binary(self, binary, operator, **kw):
+        return self._visit_bitwise(binary, "BITXOR", **kw)
+
+    def visit_bitwise_or_op_binary(self, binary, operator, **kw):
+        return self._visit_bitwise(binary, "BITOR", **kw)
+
+    def visit_bitwise_and_op_binary(self, binary, operator, **kw):
+        return self._visit_bitwise(binary, "BITAND", **kw)
+
+    def visit_bitwise_rshift_op_binary(self, binary, operator, **kw):
+        raise exc.CompileError("Cannot compile bitwise_rshift in oracle")
+
+    def visit_bitwise_lshift_op_binary(self, binary, operator, **kw):
+        raise exc.CompileError("Cannot compile bitwise_lshift in oracle")
+
+    def visit_bitwise_not_op_unary_operator(self, element, operator, **kw):
+        raise exc.CompileError("Cannot compile bitwise_not in oracle")
+
 
 class OracleDDLCompiler(compiler.DDLCompiler):
     def define_constraint_cascades(self, constraint):
index 8e3f7a560e0c7af5a424b366a6321727a04a49ee..04e84a68d2ea6b2b9465c70f5fbcaf38dc905bb6 100644 (file)
@@ -1528,6 +1528,13 @@ class SQLiteCompiler(compiler.SQLCompiler):
 
         return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
 
+    def visit_bitwise_xor_op_binary(self, binary, operator, **kw):
+        # sqlite has no xor. Use "a XOR b" = "(a | b) - (a & b)".
+        kw["eager_grouping"] = True
+        or_ = self._generate_generic_binary(binary, " | ", **kw)
+        and_ = self._generate_generic_binary(binary, " & ", **kw)
+        return f"({or_} - {and_})"
+
 
 class SQLiteDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
index ee175524fb0847929314a157eb4b73683f06d126..3b53dd943f40ebe9f7a2fd9f659f7985cc2527f0 100644 (file)
@@ -1776,3 +1776,28 @@ class SuiteRequirements(Requirements):
     def materialized_views_reflect_pk(self):
         """Target database reflect MATERIALIZED VIEWs pks."""
         return exclusions.closed()
+
+    @property
+    def supports_bitwise_or(self):
+        """Target database supports bitwise or"""
+        return exclusions.closed()
+
+    @property
+    def supports_bitwise_and(self):
+        """Target database supports bitwise and"""
+        return exclusions.closed()
+
+    @property
+    def supports_bitwise_not(self):
+        """Target database supports bitwise not"""
+        return exclusions.closed()
+
+    @property
+    def supports_bitwise_xor(self):
+        """Target database supports bitwise xor"""
+        return exclusions.closed()
+
+    @property
+    def supports_bitwise_shift(self):
+        """Target database supports bitwise left or right shift"""
+        return exclusions.closed()
index 882ca4596786f6c7e592441ede87cdb6d178b351..d81e5a04c89b6c5e7a0c2ad8e026efa961a23ae9 100644 (file)
@@ -1951,3 +1951,63 @@ class WindowFunctionTest(fixtures.TablesTest):
                 ).all()
 
                 eq_(result_rows, [(i,) for i in expected])
+
+
+class BitwiseTest(fixtures.TablesTest):
+    __backend__ = True
+    run_inserts = run_deletes = "once"
+
+    inserted_data = [{"a": i, "b": i + 1} for i in range(10)]
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("bitwise", metadata, Column("a", Integer), Column("b", Integer))
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(cls.tables.bitwise.insert(), cls.inserted_data)
+
+    @testing.combinations(
+        (
+            lambda a: a.bitwise_xor(5),
+            [i for i in range(10) if i != 5],
+            testing.requires.supports_bitwise_xor,
+        ),
+        (
+            lambda a: a.bitwise_or(1),
+            list(range(10)),
+            testing.requires.supports_bitwise_or,
+        ),
+        (
+            lambda a: a.bitwise_and(4),
+            list(range(4, 8)),
+            testing.requires.supports_bitwise_and,
+        ),
+        (
+            lambda a: (a - 2).bitwise_not(),
+            [0],
+            testing.requires.supports_bitwise_not,
+        ),
+        (
+            lambda a: a.bitwise_lshift(1),
+            list(range(1, 10)),
+            testing.requires.supports_bitwise_shift,
+        ),
+        (
+            lambda a: a.bitwise_rshift(2),
+            list(range(4, 10)),
+            testing.requires.supports_bitwise_shift,
+        ),
+        argnames="case, expected",
+    )
+    def test_bitwise(self, case, expected, connection):
+        tbl = self.tables.bitwise
+
+        a = tbl.c.a
+
+        op = testing.resolve_lambda(case, a=a)
+
+        stmt = select(tbl).where(op > 0).order_by(a)
+
+        res = connection.execute(stmt).mappings().all()
+        eq_(res, [self.inserted_data[i] for i in expected])
index 0f6fb3f0e3829f8d59f5a9e3cf3ee8dbd2b875ce..9d12652de25ca4f8e8b541abbdf673489c9a2514 100644 (file)
@@ -2068,3 +2068,28 @@ class DefaultRequirements(SuiteRequirements):
         statement.
         """
         return only_on(["mssql"])
+
+    @property
+    def supports_bitwise_and(self):
+        """Target database supports bitwise and"""
+        return exclusions.open()
+
+    @property
+    def supports_bitwise_or(self):
+        """Target database supports bitwise or"""
+        return fails_on(["oracle<21"])
+
+    @property
+    def supports_bitwise_not(self):
+        """Target database supports bitwise not"""
+        return fails_on(["oracle", "mysql", "mariadb"])
+
+    @property
+    def supports_bitwise_xor(self):
+        """Target database supports bitwise xor"""
+        return fails_on(["oracle<21"])
+
+    @property
+    def supports_bitwise_shift(self):
+        """Target database supports bitwise left or right shift"""
+        return fails_on(["oracle"])