]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
invoke mariadb-connector .rowcount after all statements
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 28 Sep 2023 12:58:16 +0000 (08:58 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Sep 2023 16:17:26 +0000 (12:17 -0400)
Modified the mariadb-connector driver to pre-load the ``cursor.rowcount``
value for all queries, to suit tools such as Pandas that hardcode to
calling :attr:`.Result.rowcount` in this way. SQLAlchemy normally pre-loads
``cursor.rowcount`` only for UPDATE/DELETE statements and otherwise passes
through to the DBAPI where it can return -1 if no value is available.
However, mariadb-connector does not support invoking ``cursor.rowcount``
after the cursor itself is closed, raising an error instead.  Generic test
support has been added to ensure all backends support the allowing
:attr:`.Result.rowcount` to succceed (that is, returning an integer value
with -1 for "not available") after the result is closed.

This change also restores mariadb-connector to CI including
as part of the "dbdriver" suite; in 366a5e3e2e503a20ef0334fbf9f we had
taken it out of the DBAPI main job.

Additional fixes for the mariadb-connector dialect to support UUID data
values in the result in INSERT..RETURNING statements.

Added rounding to one remaining INSERT..RETURNING with floats test
to allow mariadbconnector to pass (likely similar issue as the one with
UUID but not worth making a new handler)

Fixes: #10396
Change-Id: Ic11b1b5d0c41356863829d0eacbb812d401e8dd1

doc/build/changelog/unreleased_20/10396.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/mariadbconnector.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/testing/suite/test_insert.py
lib/sqlalchemy/testing/suite/test_rowcount.py
test/requirements.py
tox.ini

diff --git a/doc/build/changelog/unreleased_20/10396.rst b/doc/build/changelog/unreleased_20/10396.rst
new file mode 100644 (file)
index 0000000..e2e33eb
--- /dev/null
@@ -0,0 +1,22 @@
+.. change::
+    :tags: bug, mariadb
+    :tickets: 10396
+
+    Modified the mariadb-connector driver to pre-load the ``cursor.rowcount``
+    value for all queries, to suit tools such as Pandas that hardcode to
+    calling :attr:`.Result.rowcount` in this way. SQLAlchemy normally pre-loads
+    ``cursor.rowcount`` only for UPDATE/DELETE statements and otherwise passes
+    through to the DBAPI where it can return -1 if no value is available.
+    However, mariadb-connector does not support invoking ``cursor.rowcount``
+    after the cursor itself is closed, raising an error instead.  Generic test
+    support has been added to ensure all backends support the allowing
+    :attr:`.Result.rowcount` to succceed (that is, returning an integer
+    value with -1 for "not available") after the result is closed.
+
+
+
+.. change::
+    :tags: bug, mariadb
+
+    Additional fixes for the mariadb-connector dialect to support UUID data
+    values in the result in INSERT..RETURNING statements.
index 896b332c12388b1c836147aaaca77f66733a67a8..df9e84ad51a599e70955a7f3015c08dc0958c126 100644 (file)
@@ -30,16 +30,46 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
 
 """  # noqa
 import re
+from uuid import UUID as _python_UUID
 
 from .base import MySQLCompiler
 from .base import MySQLDialect
 from .base import MySQLExecutionContext
 from ... import sql
 from ... import util
+from ...sql import sqltypes
+
 
 mariadb_cpy_minimum_version = (1, 0, 1)
 
 
+class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
+    # work around JIRA issue
+    # https://jira.mariadb.org/browse/CONPY-270.  When that issue is fixed,
+    # this type can be removed.
+    def result_processor(self, dialect, coltype):
+        if self.as_uuid:
+
+            def process(value):
+                if value is not None:
+                    if hasattr(value, "decode"):
+                        value = value.decode("ascii")
+                    value = _python_UUID(value)
+                return value
+
+            return process
+        else:
+
+            def process(value):
+                if value is not None:
+                    if hasattr(value, "decode"):
+                        value = value.decode("ascii")
+                    value = str(_python_UUID(value))
+                return value
+
+            return process
+
+
 class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
     _lastrowid = None
 
@@ -50,9 +80,18 @@ class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
         return self._dbapi_connection.cursor(buffered=True)
 
     def post_exec(self):
+        self._rowcount = self.cursor.rowcount
+
         if self.isinsert and self.compiled.postfetch_lastrowid:
             self._lastrowid = self.cursor.lastrowid
 
+    @property
+    def rowcount(self):
+        if self._rowcount is not None:
+            return self._rowcount
+        else:
+            return self.cursor.rowcount
+
     def get_lastrowid(self):
         return self._lastrowid
 
@@ -87,6 +126,10 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
 
     supports_server_side_cursors = True
 
+    colspecs = util.update_copy(
+        MySQLDialect.colspecs, {sqltypes.Uuid: _MariaDBUUID}
+    )
+
     @util.memoized_property
     def _dbapi_version(self):
         if self.dbapi and hasattr(self.dbapi, "__version__"):
index da51f35a810768fab2c0c1548d7adc8c4ca9320a..a73d1699de5b4b068239feedc72603e467c5441d 100644 (file)
@@ -978,8 +978,8 @@ class OracleDialect_cx_oracle(OracleDialect):
 
     driver = "cx_oracle"
 
-    colspecs = OracleDialect.colspecs
-    colspecs.update(
+    colspecs = util.update_copy(
+        OracleDialect.colspecs,
         {
             sqltypes.TIMESTAMP: _CXOracleTIMESTAMP,
             sqltypes.Numeric: _OracleNumeric,
@@ -1006,7 +1006,7 @@ class OracleDialect_cx_oracle(OracleDialect):
             sqltypes.Uuid: _OracleUUID,
             oracle.NCLOB: _OracleUnicodeTextNCLOB,
             oracle.ROWID: _OracleRowid,
-        }
+        },
     )
 
     execute_sequence_format = list
index 246cf6fe78009537ed35048aa0ae346525abdb93..45af49afccb877f14644d61fe18b2d04e9c4b479 100644 (file)
@@ -1995,7 +1995,7 @@ class CursorResult(Result[_T]):
            * :attr:`_engine.CursorResult.rowcount`
              is *only* useful in conjunction
              with an UPDATE or DELETE statement.  Contrary to what the Python
-             DBAPI says, it does *not* return the
+             DBAPI says, it does *not* reliably return the
              number of rows available from the results of a SELECT statement
              as DBAPIs cannot support this functionality when rows are
              unbuffered.
index e164605e4fdfd96e942154a7951de2210987f570..a893e30334733501dcb9f75f4abfbb57012b6cbe 100644 (file)
@@ -394,7 +394,7 @@ class ReturningTest(fixtures.TablesTest):
             True,
             testing.requires.float_or_double_precision_behaves_generically,
         ),
-        (Float(), 8.5514, False),
+        (Float(), 8.5514, True),
         (
             Float(8),
             8.5514,
index ba8e1043772caa50c1a988aab1948307442c0695..58295a5c531570bac794696b7c85c73a6f6ad832 100644 (file)
@@ -66,6 +66,49 @@ class RowCountTest(fixtures.TablesTest):
 
         eq_(rows, self.data)
 
+    @testing.variation("statement", ["update", "delete", "insert", "select"])
+    @testing.variation("close_first", [True, False])
+    def test_non_rowcount_scenarios_no_raise(
+        self, connection, statement, close_first
+    ):
+        employees_table = self.tables.employees
+
+        # WHERE matches 3, 3 rows changed
+        department = employees_table.c.department
+
+        if statement.update:
+            r = connection.execute(
+                employees_table.update().where(department == "C"),
+                {"department": "Z"},
+            )
+        elif statement.delete:
+            r = connection.execute(
+                employees_table.delete().where(department == "C"),
+                {"department": "Z"},
+            )
+        elif statement.insert:
+            r = connection.execute(
+                employees_table.insert(),
+                [
+                    {"employee_id": 25, "name": "none 1", "department": "X"},
+                    {"employee_id": 26, "name": "none 2", "department": "Z"},
+                    {"employee_id": 27, "name": "none 3", "department": "Z"},
+                ],
+            )
+        elif statement.select:
+            s = select(
+                employees_table.c.name, employees_table.c.department
+            ).where(employees_table.c.department == "C")
+            r = connection.execute(s)
+            r.all()
+        else:
+            statement.fail()
+
+        if close_first:
+            r.close()
+
+        assert r.rowcount in (-1, 3)
+
     def test_update_rowcount1(self, connection):
         employees_table = self.tables.employees
 
index 95083aed0b5ed1f85d1290d8e8133e053e5f0142..88798d6cd7b6019fa88a2552a5dc4501126eb52d 100644 (file)
@@ -1246,14 +1246,16 @@ class DefaultRequirements(SuiteRequirements):
 
         """
 
-        # mariadbconnector works.  pyodbc we dont know, not supported in
-        # testing.
+        # this may have worked with mariadbconnector at some point, but
+        # this now seems to not be the case.   Since no other mysql driver
+        # supports these tests, that's fine
         return exclusions.fails_on(
             [
                 "+mysqldb",
                 "+pymysql",
                 "+asyncmy",
                 "+mysqlconnector",
+                "+mariadbconnector",
                 "+cymysql",
                 "+aiomysql",
             ]
diff --git a/tox.ini b/tox.ini
index 1e0a5769832641131de8e5f39b8cb3d22ae04d72..895419d2894e965c8fb920b1b1600dcdd31659a6 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -26,6 +26,7 @@ extras=
      mysql: pymysql
      mysql: asyncmy
      mysql: aiomysql
+     mysql: mariadb_connector
 
      oracle: oracle
      oracle: oracle_oracledb
@@ -131,9 +132,7 @@ setenv=
     memusage: WORKERS={env:TOX_WORKERS:-n2}
 
     mysql: MYSQL={env:TOX_MYSQL:--db mysql}
-    mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
-
-    mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver asyncmy --dbdriver aiomysql}
+    mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver asyncmy --dbdriver aiomysql --dbdriver mariadbconnector}
 
     mssql: MSSQL={env:TOX_MSSQL:--db mssql}
     py{3,37,38,39,310,311}-mssql: EXTRA_MSSQL_DRIVERS={env:EXTRA_MSSQL_DRIVERS:--dbdriver pyodbc --dbdriver pymssql}