From: Mike Bayer
Date: Tue, 18 Oct 2022 13:44:37 +0000 (-0400)
Subject: further qualify pyodbc setinputsizes types for long strings
X-Git-Tag: rel_2_0_0b2~10^2
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f4214975a7deb5e13f8b6cf21e39697821396a7f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

further qualify pyodbc setinputsizes types for long strings

Fixed regression caused by SQL Server pyodbc change :ticket:`8177` where we
now use ``setinputsizes()`` by default; for VARCHAR, this fails if the
character size is greater than 4000 (or 2000, depending on data) characters
as the incoming datatype is NVARCHAR, which has a limit of 4000 characters,
despite the fact that VARCHAR can handle unlimited characters. Additional
pyodbc-specific typing information is now passed to ``setinputsizes()``
when the datatype's size is > 2000 characters. The change is also applied
to the :class:`.JSON` type which was also impacted by this issue for large
JSON serializations.

Fixes: #8661
Change-Id: I07fa873e95dbd2c94f3d286e93e8b3229c3a9807
---

diff --git a/doc/build/changelog/unreleased_20/8661.rst b/doc/build/changelog/unreleased_20/8661.rst
new file mode 100644
index 0000000000..80dcc5ee4e
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/8661.rst
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 8661
+
+    Fixed regression caused by SQL Server pyodbc change :ticket:`8177` where we
+    now use ``setinputsizes()`` by default; for VARCHAR, this fails if the
+    character size is greater than 4000 (or 2000, depending on data) characters
+    as the incoming datatype is NVARCHAR, which has a limit of 4000 characters,
+    despite the fact that VARCHAR can handle unlimited characters. Additional
+    pyodbc-specific typing information is now passed to ``setinputsizes()``
+    when the datatype's size is > 2000 characters. The change is also applied
+    to the :class:`.JSON` type which was also impacted by this issue for large
+    JSON serializations.
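As background for the dialect change below: pyodbc's ``Cursor.setinputsizes()`` accepts, for each parameter, either a single SQL type constant or a ``(sql_type, column_size, decimal_digits)`` tuple. The updated ``get_dbapi_type()`` methods return the tuple form with a column size of 0 for long / "max" datatypes, which is the additional typing information that avoids the NVARCHAR 4000-character bind failure described above. A minimal raw-pyodbc sketch of the same idea follows; the connection string and table name are hypothetical, and this bypasses SQLAlchemy entirely:

    import pyodbc

    # hypothetical connection string; adjust driver/server/credentials
    cnxn = pyodbc.connect(
        "DRIVER={ODBC Driver 17 for SQL Server};"
        "SERVER=myserver;DATABASE=mydb;UID=scott;PWD=tiger"
    )
    crsr = cnxn.cursor()

    long_value = "x" * 6000  # longer than an NVARCHAR(4000) bind can hold

    # hint the bind type before executing, using the
    # (sql_type, column_size, decimal_digits) tuple form; a column_size of 0
    # is what the patch passes for long / "max" datatypes
    crsr.setinputsizes([(pyodbc.SQL_WVARCHAR, 0, 0)])
    crsr.execute("INSERT INTO some_table (data) VALUES (?)", long_value)
    cnxn.commit()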
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
index 5eb6b95282..09b4a80b6e 100644
--- a/lib/sqlalchemy/dialects/mssql/pyodbc.py
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -516,22 +516,31 @@ class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
 
 class _String_pyodbc(sqltypes.String):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_VARCHAR
+        if self.length in (None, "max") or self.length >= 2000:
+            return (dbapi.SQL_VARCHAR, 0, 0)
+        else:
+            return dbapi.SQL_VARCHAR
 
 
 class _Unicode_pyodbc(_MSUnicode):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_WVARCHAR
+        if self.length in (None, "max") or self.length >= 2000:
+            return (dbapi.SQL_WVARCHAR, 0, 0)
+        else:
+            return dbapi.SQL_WVARCHAR
 
 
 class _UnicodeText_pyodbc(_MSUnicodeText):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_WVARCHAR
+        if self.length in (None, "max") or self.length >= 2000:
+            return (dbapi.SQL_WVARCHAR, 0, 0)
+        else:
+            return dbapi.SQL_WVARCHAR
 
 
 class _JSON_pyodbc(_MSJson):
     def get_dbapi_type(self, dbapi):
-        return dbapi.SQL_WVARCHAR
+        return (dbapi.SQL_WVARCHAR, 0, 0)
 
 
 class _JSONIndexType_pyodbc(_MSJsonIndexType):
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 1cb463977b..44133984bb 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -12,10 +12,11 @@ from __future__ import annotations
 import collections
 import typing
 from typing import Any
+from typing import Callable
 from typing import Iterable
 from typing import Optional
-from typing import overload
 from typing import Tuple
+from typing import TypeVar
 from typing import Union
 
 from .. import util
@@ -39,16 +40,15 @@ else:
 _fixture_functions = None  # installed by plugin_base
 
 
-@overload
+_FN = TypeVar("_FN", bound=Callable[..., Any])
+
+
 def combinations(
     *comb: Union[Any, Tuple[Any, ...]],
     argnames: Optional[str] = None,
     id_: Optional[str] = None,
-):
-    ...
-
-
-def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str):
+    **kw: str,
+) -> Callable[[_FN], _FN]:
     r"""Deliver multiple versions of a test based on positional combinations.
 
     This is a facade over pytest.mark.parametrize.
@@ -111,7 +111,9 @@ def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str):
 
     """
 
-    return _fixture_functions.combinations(*comb, **kw)
+    return _fixture_functions.combinations(
+        *comb, id_=id_, argnames=argnames, **kw
+    )
 
 
 def combinations_list(
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 9461298b9f..25ed041c2a 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -1148,6 +1148,21 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_round_trip_data1(self, connection):
         self._test_round_trip({"key1": "value1", "key2": "value2"}, connection)
 
+    @testing.combinations(
+        ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia"
+    )
+    @testing.combinations(100, 1999, 3000, 4000, 5000, 9000, argnames="length")
+    def test_round_trip_pretty_large_data(self, connection, unicode_, length):
+
+        if unicode_:
+            data = "réve🐍illé" * ((length // 9) + 1)
+            data = data[0 : (length // 2)]
+        else:
+            data = "abcdefg" * ((length // 7) + 1)
+            data = data[0:length]
+
+        self._test_round_trip({"key1": data, "key2": data}, connection)
+
     def _test_round_trip(self, data_element, connection):
         data_table = self.tables.data_table
 
diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py
index ff84f180bd..eb14cb30f7 100644
--- a/test/dialect/mssql/test_types.py
+++ b/test/dialect/mssql/test_types.py
@@ -6,6 +6,7 @@ import os
 
 import sqlalchemy as sa
 from sqlalchemy import Boolean
+from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import Date
@@ -18,6 +19,7 @@ from sqlalchemy import LargeBinary
 from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import Numeric
+from sqlalchemy import NVARCHAR
 from sqlalchemy import PickleType
 from sqlalchemy import schema
 from sqlalchemy import select
@@ -32,6 +34,7 @@ from sqlalchemy import types
 from sqlalchemy import Unicode
 from sqlalchemy import UnicodeText
 from sqlalchemy.dialects.mssql import base as mssql
+from sqlalchemy.dialects.mssql import NTEXT
 from sqlalchemy.dialects.mssql import ROWVERSION
 from sqlalchemy.dialects.mssql import TIMESTAMP
 from sqlalchemy.dialects.mssql import UNIQUEIDENTIFIER
@@ -1236,6 +1239,170 @@ class StringTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
+class StringRoundTripTest(fixtures.TestBase):
+    """tests for #8661
+
+
+    at the moment most of these are using the default setinputsizes enabled
+    behavior, with the exception of the plain executemany() calls for inserts.
+
+
+    """
+
+    __only_on__ = "mssql"
+    __backend__ = True
+
+    @testing.combinations(
+        ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia"
+    )
+    @testing.combinations(
+        String,
+        Unicode,
+        NVARCHAR,
+        NTEXT,
+        Text,
+        UnicodeText,
+        argnames="datatype",
+    )
+    @testing.combinations(
+        100, 1999, 2000, 2001, 3999, 4000, 4001, 5000, argnames="length"
+    )
+    def test_long_strings_inplace(
+        self, connection, unicode_, length, datatype
+    ):
+        if datatype is NVARCHAR and length != "max" and length > 4000:
+            return
+        elif unicode_ and datatype not in (NVARCHAR, UnicodeText):
+            return
+
+        if datatype in (String, NVARCHAR):
+            dt = datatype(length)
+        else:
+            dt = datatype()
+
+        if length == "max":
+            length = 12000
+
+        if unicode_:
+            data = "réve🐍illé" * ((length // 9) + 1)
+            data = data[0 : (length // 2)]
+        else:
+            data = "abcdefg" * ((length // 7) + 1)
+            data = data[0:length]
+        assert len(data) == length
+
+        stmt = select(cast(literal(data, type_=dt), type_=dt))
+        result = connection.scalar(stmt)
+        eq_(result, data)
+
+    @testing.combinations(
+        ("unicode", True), ("ascii", False), argnames="unicode_", id_="ia"
+    )
+    @testing.combinations(
+        ("returning", True),
+        ("noreturning", False),
+        argnames="use_returning",
+        id_="ia",
+    )
+    @testing.combinations(
+        ("insertmany", True),
+        ("insertsingle", False),
+        argnames="insertmany",
+        id_="ia",
+    )
+    @testing.combinations(
+        String,
+        Unicode,
+        NVARCHAR,
+        NTEXT,
+        Text,
+        UnicodeText,
+        argnames="datatype",
+    )
+    @testing.combinations(
+        100, 1999, 2000, 2001, 3999, 4000, 4001, 5000, "max", argnames="length"
+    )
+    def test_long_strings_in_context(
+        self,
+        connection,
+        metadata,
+        unicode_,
+        length,
+        datatype,
+        use_returning,
+        insertmany,
+    ):
+
+        if datatype is NVARCHAR and length != "max" and length > 4000:
+            return
+        elif unicode_ and datatype not in (NVARCHAR, UnicodeText):
+            return
+
+        if datatype in (String, NVARCHAR):
+            dt = datatype(length)
+        else:
+            dt = datatype()
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", dt),
+        )
+
+        t.create(connection)
+
+        if length == "max":
+            length = 12000
+
+        if unicode_:
+            data = "réve🐍illé" * ((length // 9) + 1)
+            data = data[0 : (length // 2)]
+        else:
+            data = "abcdefg" * ((length // 7) + 1)
+            data = data[0:length]
+        assert len(data) == length
+
+        if insertmany:
+            insert_data = [{"data": data}, {"data": data}, {"data": data}]
+            expected_data = [data, data, data]
+        else:
+            insert_data = {"data": data}
+            expected_data = [data]
+
+        if use_returning:
+            result = connection.execute(
+                t.insert().returning(t.c.data), insert_data
+            )
+            eq_(result.scalars().all(), expected_data)
+        else:
+            connection.execute(t.insert(), insert_data)
+
+        result = connection.scalars(select(t.c.data))
+        eq_(result.all(), expected_data)
+
+        # note that deprecate_large_types indicates that UnicodeText
+        # will be fulfilled by NVARCHAR, and not NTEXT. However if NTEXT
+        # is used directly, it isn't supported in WHERE clauses:
+        # "The data types ntext and (anything, including ntext itself) are
+        # incompatible in the equal to operator."
+        if datatype is NTEXT:
+            return
+
+        # test WHERE criteria
+        connection.execute(
+            t.insert(), [{"data": "some other data %d" % i} for i in range(3)]
+        )
+
+        result = connection.scalars(select(t.c.data).where(t.c.data == data))
+        eq_(result.all(), expected_data)
+
+        result = connection.scalars(
+            select(t.c.data).where(t.c.data != data).order_by(t.c.id)
+        )
+        eq_(result.all(), ["some other data %d" % i for i in range(3)])
+
+
 class UniqueIdentifierTest(test_types.UuidTest):
     __only_on__ = "mssql"
     __backend__ = True
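At the SQLAlchemy level, the scenario this commit repairs is an ordinary insert or round trip of a string or JSON value longer than 4000 characters through ``mssql+pyodbc``. A rough end-to-end sketch of that scenario follows; the connection URL and table are hypothetical and this is illustration only, not part of the patch:

    from sqlalchemy import Column
    from sqlalchemy import create_engine
    from sqlalchemy import Integer
    from sqlalchemy import JSON
    from sqlalchemy import MetaData
    from sqlalchemy import select
    from sqlalchemy import Table
    from sqlalchemy import UnicodeText

    # hypothetical pyodbc DSN; any mssql+pyodbc URL applies
    engine = create_engine("mssql+pyodbc://scott:tiger@mydsn")

    metadata = MetaData()
    documents = Table(
        "documents",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("body", UnicodeText()),  # typically rendered as NVARCHAR(max)
        Column("payload", JSON()),
    )

    big_text = "réveillé " * 1000  # ~9000 characters, past the 4000-character bind limit
    big_json = {"key1": big_text, "key2": big_text}

    with engine.begin() as conn:
        metadata.create_all(conn)
        conn.execute(
            documents.insert(), {"body": big_text, "payload": big_json}
        )
        print(conn.execute(select(documents.c.body)).scalar_one()[:20])

Before this change, the ``setinputsizes()`` call introduced by :ticket:`8177` could describe these parameters as plain NVARCHAR and the statement would fail once the data exceeded 4000 (or, depending on the data, 2000) characters; with the additional typing information now passed to ``setinputsizes()``, the insert and round trip succeed.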