further qualify pyodbc setinputsizes types for long strings
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Tue, 18 Oct 2022 13:44:37 +0000 (09:44 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Tue, 18 Oct 2022 18:11:11 +0000 (14:11 -0400)
Fixed regression caused by the SQL Server pyodbc change in :ticket:`8177`,
where ``setinputsizes()`` is now used by default; for VARCHAR, this fails
when the string is longer than 4000 (or 2000, depending on data)
characters, as the incoming datatype is bound as NVARCHAR, which has a
limit of 4000 characters, even though VARCHAR itself can handle unlimited
characters. Additional pyodbc-specific typing information is now passed to
``setinputsizes()`` when the datatype's size is greater than 2000
characters. The change is also applied to the :class:`.JSON` type, which
was likewise impacted by this issue for large JSON serializations.

Fixes: #8661
Change-Id: I07fa873e95dbd2c94f3d286e93e8b3229c3a9807
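
For context, here is a minimal pyodbc-level sketch of the tuple form of
``setinputsizes()`` that this commit relies on. The DSN and table name are
hypothetical; pyodbc accepts per-parameter tuples of
``(sql_type, size, decimal_digits)``, and a size of 0 asks the driver to
bind the parameter without a fixed cap:

    import pyodbc

    # hypothetical DSN and table, for illustration only
    cnxn = pyodbc.connect("DSN=mymssql")
    crsr = cnxn.cursor()

    # a bare pyodbc.SQL_WVARCHAR can end up bound as a fixed-size NVARCHAR
    # and fail past 4000 characters; the (type, 0, 0) tuple requests an
    # unbounded bind instead
    crsr.setinputsizes([(pyodbc.SQL_WVARCHAR, 0, 0)])
    crsr.execute("INSERT INTO t (data) VALUES (?)", ("x" * 10000,))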

doc/build/changelog/unreleased_20/8661.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/mssql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/8661.rst b/doc/build/changelog/unreleased_20/8661.rst
new file mode 100644 (file)
index 0000000..80dcc5e
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 8661
+
+    Fixed regression caused by the SQL Server pyodbc change in :ticket:`8177`,
+    where ``setinputsizes()`` is now used by default; for VARCHAR, this fails
+    when the string is longer than 4000 (or 2000, depending on data)
+    characters, as the incoming datatype is bound as NVARCHAR, which has a
+    limit of 4000 characters, even though VARCHAR itself can handle unlimited
+    characters. Additional pyodbc-specific typing information is now passed to
+    ``setinputsizes()`` when the datatype's size is greater than 2000
+    characters. The change is also applied to the :class:`.JSON` type, which
+    was likewise impacted by this issue for large JSON serializations.
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 5eb6b952824f3d5d8f79e3c4a565c00e6bdfcf05..09b4a80b6e34009a8a560a1a3ff0cface0af1a8e 100644 (file)
@@ -516,22 +516,31 @@ class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
 
 class _String_pyodbc(sqltypes.String):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_VARCHAR
+        if self.length in (None, "max") or self.length >= 2000:
+            return (dbapi.SQL_VARCHAR, 0, 0)
+        else:
+            return dbapi.SQL_VARCHAR
 
 
 class _Unicode_pyodbc(_MSUnicode):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_WVARCHAR
+        if self.length in (None, "max") or self.length >= 2000:
+            return (dbapi.SQL_WVARCHAR, 0, 0)
+        else:
+            return dbapi.SQL_WVARCHAR
 
 
 class _UnicodeText_pyodbc(_MSUnicodeText):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_WVARCHAR
+        if self.length in (None, "max") or self.length >= 2000:
+            return (dbapi.SQL_WVARCHAR, 0, 0)
+        else:
+            return dbapi.SQL_WVARCHAR
 
 
 class _JSON_pyodbc(_MSJson):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_WVARCHAR
+        return (dbapi.SQL_WVARCHAR, 0, 0)
 
 
 class _JSONIndexType_pyodbc(_MSJsonIndexType):
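
As an illustrative check (not part of the commit), the threshold behavior
of these private types can be observed directly; this assumes a pyodbc
installation and imports private names that may change:

    import pyodbc
    from sqlalchemy.dialects.mssql.pyodbc import _String_pyodbc

    # below the 2000-character threshold the bare ODBC type code is used
    print(_String_pyodbc(50).get_dbapi_type(pyodbc))    # 12 (SQL_VARCHAR)
    # at or above it (or with length None/"max") the tuple form is used
    print(_String_pyodbc(5000).get_dbapi_type(pyodbc))  # (12, 0, 0)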
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 1cb463977b2858a9122889844f5e89cb77ac0b22..44133984bbc6d3e011cba7bad712125d66f509ce 100644 (file)
@@ -12,10 +12,11 @@ from __future__ import annotations
 import collections
 import typing
 from typing import Any
+from typing import Callable
 from typing import Iterable
 from typing import Optional
-from typing import overload
 from typing import Tuple
+from typing import TypeVar
 from typing import Union
 
 from .. import util
@@ -39,16 +40,15 @@ else:
     _fixture_functions = None  # installed by plugin_base
 
 
-@overload
+_FN = TypeVar("_FN", bound=Callable[..., Any])
+
+
 def combinations(
     *comb: Union[Any, Tuple[Any, ...]],
     argnames: Optional[str] = None,
     id_: Optional[str] = None,
-):
-    ...
-
-
-def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str):
+    **kw: str,
+) -> Callable[[_FN], _FN]:
     r"""Deliver multiple versions of a test based on positional combinations.
 
     This is a facade over pytest.mark.parametrize.
@@ -111,7 +111,9 @@ def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str):
 
 
     """
-    return _fixture_functions.combinations(*comb, **kw)
+    return _fixture_functions.combinations(
+        *comb, id_=id_, argnames=argnames, **kw
+    )
 
 
 def combinations_list(
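
The ``combinations()`` change above collapses an ``@overload`` pair into a
single ``TypeVar``-based signature so that type checkers see the decorated
test come back with its original signature. A standalone sketch of the same
decorator-typing pattern, with illustrative names:

    from typing import Any, Callable, TypeVar

    _FN = TypeVar("_FN", bound=Callable[..., Any])

    def marker(label: str) -> Callable[[_FN], _FN]:
        # returning Callable[[_FN], _FN] tells the type checker that the
        # decorated function is returned unchanged
        def decorate(fn: _FN) -> _FN:
            setattr(fn, "_label", label)
            return fn

        return decorate

    @marker("smoke")
    def test_something() -> None:
        ...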
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 9461298b9fcffeb182ebfc3d6e77de45b5851c99..25ed041c2a620aa1aeb99bee0632e05408f65735 100644 (file)
@@ -1148,6 +1148,21 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_round_trip_data1(self, connection):
         self._test_round_trip({"key1": "value1", "key2": "value2"}, connection)
 
+    @testing.combinations(
+        ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia"
+    )
+    @testing.combinations(100, 1999, 3000, 4000, 5000, 9000, argnames="length")
+    def test_round_trip_pretty_large_data(self, connection, unicode_, length):
+
+        if unicode_:
+            data = "réve🐍illé" * ((length // 9) + 1)
+            data = data[0 : (length // 2)]
+        else:
+            data = "abcdefg" * ((length // 7) + 1)
+            data = data[0:length]
+
+        self._test_round_trip({"key1": data, "key2": data}, connection)
+
     def _test_round_trip(self, data_element, connection):
         data_table = self.tables.data_table
 
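A note on the ``length // 2`` slice in the unicode branch above: SQL Server
measures NVARCHAR length in UTF-16 code units rather than in Python code
points, and the snake emoji in the sample string occupies two units, so
halving presumably keeps the payload within the intended size. The
accounting can be checked directly:

    data = "réve🐍illé"
    # Python counts code points; NVARCHAR counts UTF-16 code units
    print(len(data))                           # 9 code points
    print(len(data.encode("utf-16-le")) // 2)  # 10 code units (🐍 uses 2)
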
diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py
index ff84f180bdfcf68549cb4055b16550f6fa719f06..eb14cb30f794a9d82903b1b8215ab2c2f0e98582 100644 (file)
@@ -6,6 +6,7 @@ import os
 
 import sqlalchemy as sa
 from sqlalchemy import Boolean
+from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import Date
@@ -18,6 +19,7 @@ from sqlalchemy import LargeBinary
 from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import Numeric
+from sqlalchemy import NVARCHAR
 from sqlalchemy import PickleType
 from sqlalchemy import schema
 from sqlalchemy import select
@@ -32,6 +34,7 @@ from sqlalchemy import types
 from sqlalchemy import Unicode
 from sqlalchemy import UnicodeText
 from sqlalchemy.dialects.mssql import base as mssql
+from sqlalchemy.dialects.mssql import NTEXT
 from sqlalchemy.dialects.mssql import ROWVERSION
 from sqlalchemy.dialects.mssql import TIMESTAMP
 from sqlalchemy.dialects.mssql import UNIQUEIDENTIFIER
@@ -1236,6 +1239,170 @@ class StringTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
+class StringRoundTripTest(fixtures.TestBase):
+    """tests for #8661
+
+
+    at the moment most of these are using the default setinputsizes enabled
+    behavior, with the exception of the plain executemany() calls for inserts.
+
+
+    """
+
+    __only_on__ = "mssql"
+    __backend__ = True
+
+    @testing.combinations(
+        ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia"
+    )
+    @testing.combinations(
+        String,
+        Unicode,
+        NVARCHAR,
+        NTEXT,
+        Text,
+        UnicodeText,
+        argnames="datatype",
+    )
+    @testing.combinations(
+        100, 1999, 2000, 2001, 3999, 4000, 4001, 5000, argnames="length"
+    )
+    def test_long_strings_inplace(
+        self, connection, unicode_, length, datatype
+    ):
+        if datatype is NVARCHAR and length != "max" and length > 4000:
+            return
+        elif unicode_ and datatype not in (NVARCHAR, UnicodeText):
+            return
+
+        if datatype in (String, NVARCHAR):
+            dt = datatype(length)
+        else:
+            dt = datatype()
+
+        if length == "max":
+            length = 12000
+
+        if unicode_:
+            data = "réve🐍illé" * ((length // 9) + 1)
+            data = data[0 : (length // 2)]
+        else:
+            data = "abcdefg" * ((length // 7) + 1)
+            data = data[0:length]
+            assert len(data) == length
+
+        stmt = select(cast(literal(data, type_=dt), type_=dt))
+        result = connection.scalar(stmt)
+        eq_(result, data)
+
+    @testing.combinations(
+        ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia"
+    )
+    @testing.combinations(
+        ("returning", True),
+        ("noreturning", False),
+        argnames="use_returning",
+        id_="ia",
+    )
+    @testing.combinations(
+        ("insertmany", True),
+        ("insertsingle", False),
+        argnames="insertmany",
+        id_="ia",
+    )
+    @testing.combinations(
+        String,
+        Unicode,
+        NVARCHAR,
+        NTEXT,
+        Text,
+        UnicodeText,
+        argnames="datatype",
+    )
+    @testing.combinations(
+        100, 1999, 2000, 2001, 3999, 4000, 4001, 5000, "max", argnames="length"
+    )
+    def test_long_strings_in_context(
+        self,
+        connection,
+        metadata,
+        unicode_,
+        length,
+        datatype,
+        use_returning,
+        insertmany,
+    ):
+
+        if datatype is NVARCHAR and length != "max" and length > 4000:
+            return
+        elif unicode_ and datatype not in (NVARCHAR, UnicodeText):
+            return
+
+        if datatype in (String, NVARCHAR):
+            dt = datatype(length)
+        else:
+            dt = datatype()
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", dt),
+        )
+
+        t.create(connection)
+
+        if length == "max":
+            length = 12000
+
+        if unicode_:
+            data = "réve🐍illé" * ((length // 9) + 1)
+            data = data[0 : (length // 2)]
+        else:
+            data = "abcdefg" * ((length // 7) + 1)
+            data = data[0:length]
+            assert len(data) == length
+
+        if insertmany:
+            insert_data = [{"data": data}, {"data": data}, {"data": data}]
+            expected_data = [data, data, data]
+        else:
+            insert_data = {"data": data}
+            expected_data = [data]
+
+        if use_returning:
+            result = connection.execute(
+                t.insert().returning(t.c.data), insert_data
+            )
+            eq_(result.scalars().all(), expected_data)
+        else:
+            connection.execute(t.insert(), insert_data)
+
+        result = connection.scalars(select(t.c.data))
+        eq_(result.all(), expected_data)
+
+        # note that deprecate_large_types indicates that UnicodeText
+        # will be fulfilled by NVARCHAR, and not NTEXT.  However if NTEXT
+        # is used directly, it isn't supported in WHERE clauses:
+        # "The data types ntext and (anything, including ntext itself) are
+        # incompatible in the equal to operator."
+        if datatype is NTEXT:
+            return
+
+        # test WHERE criteria
+        connection.execute(
+            t.insert(), [{"data": "some other data %d" % i} for i in range(3)]
+        )
+
+        result = connection.scalars(select(t.c.data).where(t.c.data == data))
+        eq_(result.all(), expected_data)
+
+        result = connection.scalars(
+            select(t.c.data).where(t.c.data != data).order_by(t.c.id)
+        )
+        eq_(result.all(), ["some other data %d" % i for i in range(3)])
+
+
 class UniqueIdentifierTest(test_types.UuidTest):
     __only_on__ = "mssql"
     __backend__ = True