]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix is_disconnect false positive for mssql+pyodbc
authorGord Thompson <gord@gordthompson.com>
Fri, 29 May 2020 13:20:54 +0000 (07:20 -0600)
committerGord Thompson <gord@gordthompson.com>
Mon, 1 Jun 2020 12:02:20 +0000 (06:02 -0600)
Fixed an issue where the ``is_disconnect`` function in the SQL Server
pyodbc dialect was incorrectly reporting the disconnect state when the
exception messsage had a substring that matched a SQL Server ODBC error
code.

Fixes: #5359
Change-Id: I450c6818405a20f4daee20d58fce2d5ecb33e17f

doc/build/changelog/unreleased_13/5359.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/pyodbc.py
test/dialect/mssql/test_engine.py

diff --git a/doc/build/changelog/unreleased_13/5359.rst b/doc/build/changelog/unreleased_13/5359.rst
new file mode 100644 (file)
index 0000000..b5f690d
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 5359
+
+    Fixed an issue where the ``is_disconnect`` function in the SQL Server
+    pyodbc dialect was incorrectly reporting the disconnect state when the
+    exception messsage had a substring that matched a SQL Server ODBC error
+    code.
\ No newline at end of file
index ff164e8868d6707274498db048640074483e770f..6cf45e7b8e83a2565207c7cd3180d09b1b21195b 100644 (file)
@@ -419,7 +419,8 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
 
     def is_disconnect(self, e, connection, cursor):
         if isinstance(e, self.dbapi.Error):
-            for code in (
+            code = e.args[0]
+            if code in (
                 "08S01",
                 "01002",
                 "08003",
@@ -430,8 +431,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
                 "HY010",
                 "10054",
             ):
-                if code in str(e):
-                    return True
+                return True
         return super(MSDialect_pyodbc, self).is_disconnect(
             e, connection, cursor
         )
index 0dea2688a7901334c61a6a561e7f66788ddf6a44..734224ed1bf6b0752eef41030af14a93209d9e4b 100644 (file)
@@ -1,4 +1,5 @@
 # -*- encoding: utf-8
+
 from sqlalchemy import Column
 from sqlalchemy import engine_from_config
 from sqlalchemy import event
@@ -11,6 +12,8 @@ from sqlalchemy.dialects.mssql import base
 from sqlalchemy.dialects.mssql import pymssql
 from sqlalchemy.dialects.mssql import pyodbc
 from sqlalchemy.engine import url
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import assert_warnings
 from sqlalchemy.testing import engines
@@ -294,7 +297,7 @@ class ParseConnectTest(fixtures.TestBase):
         )
 
         for error in [
-            MockDBAPIError("[%s] some pyodbc message" % code)
+            MockDBAPIError(code, "[%s] some pyodbc message" % code)
             for code in [
                 "08S01",
                 "01002",
@@ -316,7 +319,9 @@ class ParseConnectTest(fixtures.TestBase):
 
         eq_(
             dialect.is_disconnect(
-                MockProgrammingError("not an error"), None, None
+                MockProgrammingError("Query with abc08007def failed"),
+                None,
+                None,
             ),
             False,
         )
@@ -511,3 +516,39 @@ class IsolationLevelDetectTest(fixtures.TestBase):
                 dialect.get_isolation_level,
                 connection,
             )
+
+
+class InvalidTransactionFalsePositiveTest(fixtures.TablesTest):
+    __only_on__ = "mssql"
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "error_t",
+            metadata,
+            Column("error_code", String(50), primary_key=True),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            cls.tables.error_t.insert(), [{"error_code": "01002"}],
+        )
+
+    def test_invalid_transaction_detection(self, connection):
+        # issue #5359
+        t = self.tables.error_t
+
+        # force duplicate PK error
+        assert_raises(
+            IntegrityError,
+            connection.execute,
+            t.insert(),
+            {"error_code": "01002"},
+        )
+
+        # this should not fail with
+        # "Can't reconnect until invalid transaction is rolled back."
+        result = connection.execute(t.select()).fetchall()
+        eq_(len(result), 1)