]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Render N'' for SQL Server unicode literals
author: Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 13 Jan 2019 23:15:52 +0000 (18:15 -0500)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Jan 2019 15:04:09 +0000 (10:04 -0500)
The ``literal_processor`` for the :class:`.Unicode` and
:class:`.UnicodeText` datatypes now renders an ``N`` character in front of
the literal string expression as required by SQL Server for Unicode string
values rendered in SQL expressions.

Note that this adds full unicode characters to the standard test suite,
which means we also need to bump MySQL provisioning up to utf8mb4.
Modern installs do not seem to be reproducing the 1271 issue locally;
if it reproduces in CI, it would be better for us to skip those ORM-centric
tests for MySQL.

Also remove unused _StringType from SQL Server dialect

Fixes: #4442
Change-Id: Id55817b3e8a2d81ddc8b7b27f85e3f1dcc1cea7e

doc/build/changelog/unreleased_13/4442.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_types.py
setup.cfg
test/dialect/mssql/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_13/4442.rst b/doc/build/changelog/unreleased_13/4442.rst
new file mode 100644 (file)
index 0000000..ea8c3b3
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+   :tags: bug, mssql
+   :tickets: 4442
+
+   The ``literal_processor`` for the :class:`.Unicode` and
+   :class:`.UnicodeText` datatypes now renders an ``N`` character in front of
+   the literal string expression as required by SQL Server for Unicode string
+   values rendered in SQL expressions.
index d98915c0fc01b582a5f91f5371c2cc6b72124e48..bdbc6254b45a84cceb6ca39784952fb576e50383 100644 (file)
@@ -1003,12 +1003,26 @@ class DATETIMEOFFSET(sqltypes.TypeEngine):
         self.precision = precision
 
 
-class _StringType(object):
+class _UnicodeLiteral(object):
+    def literal_processor(self, dialect):
+        def process(value):
+
+            value = value.replace("'", "''")
+
+            if dialect.identifier_preparer._double_percents:
+                value = value.replace("%", "%%")
+
+            return "N'%s'" % value
+
+        return process
+
 
-    """Base for MSSQL string types."""
+class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode):
+    pass
 
-    def __init__(self, collation=None):
-        super(_StringType, self).__init__(collation=collation)
+
+class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText):
+    pass
 
 
 class TIMESTAMP(sqltypes._Binary):
@@ -2124,6 +2138,8 @@ class MSDialect(default.DefaultDialect):
         sqltypes.DateTime: _MSDateTime,
         sqltypes.Date: _MSDate,
         sqltypes.Time: TIME,
+        sqltypes.Unicode: _MSUnicode,
+        sqltypes.UnicodeText: _MSUnicodeText,
     }
 
     engine_config_types = default.DefaultDialect.engine_config_types.union(
index 88dc28528a6262f1f087f48f4e9bfb12a7d316f9..70ace05110e851aead79fc97b90309dabb972877 100644 (file)
@@ -207,15 +207,12 @@ def _mysql_create_db(cfg, eng, ident):
         except Exception:
             pass
 
-        # using utf8mb4 we are getting collation errors on UNIONS:
-        # test/orm/inheritance/test_polymorphic_rel.py"
-        # 1271, u"Illegal mix of collations for operation 'UNION'"
-        conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident)
+        conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb4" % ident)
         conn.execute(
-            "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident
+            "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident
         )
         conn.execute(
-            "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident
+            "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident
         )
 
 
index a17d26edb8ad4620e16c625a470503ba68a5d934..3a216174065086fddfdd108f0a63478562dfdd15 100644 (file)
@@ -732,6 +732,13 @@ class SuiteRequirements(Requirements):
 
         return exclusions.open()
 
+    @property
+    def expressions_against_unbounded_text(self):
+        """target database supports use of an unbounded textual field in a
+        WHERE clause."""
+
+        return exclusions.open()
+
     @property
     def selectone(self):
         """target driver must support the literal statement 'select 1'"""
index ff8db5897cf47109b7fc10526601c639fa791524..4791671f31d9a2d01b28f9f1b25f0514640ebd67 100644 (file)
@@ -38,6 +38,8 @@ from ...util import u
 
 
 class _LiteralRoundTripFixture(object):
+    supports_whereclause = True
+
     @testing.provide_metadata
     def _literal_round_trip(self, type_, input_, output, filter_=None):
         """test literal rendering """
@@ -49,33 +51,47 @@ class _LiteralRoundTripFixture(object):
         t = Table("t", self.metadata, Column("x", type_))
         t.create()
 
-        for value in input_:
-            ins = (
-                t.insert()
-                .values(x=literal(value))
-                .compile(
-                    dialect=testing.db.dialect,
-                    compile_kwargs=dict(literal_binds=True),
+        with testing.db.connect() as conn:
+            for value in input_:
+                ins = (
+                    t.insert()
+                    .values(x=literal(value))
+                    .compile(
+                        dialect=testing.db.dialect,
+                        compile_kwargs=dict(literal_binds=True),
+                    )
                 )
-            )
-            testing.db.execute(ins)
+                conn.execute(ins)
+
+            if self.supports_whereclause:
+                stmt = t.select().where(t.c.x == literal(value))
+            else:
+                stmt = t.select()
 
-        for row in t.select().execute():
-            value = row[0]
-            if filter_ is not None:
-                value = filter_(value)
-            assert value in output
+            stmt = stmt.compile(
+                dialect=testing.db.dialect,
+                compile_kwargs=dict(literal_binds=True),
+            )
+            for row in conn.execute(stmt):
+                value = row[0]
+                if filter_ is not None:
+                    value = filter_(value)
+                assert value in output
 
 
 class _UnicodeFixture(_LiteralRoundTripFixture):
     __requires__ = ("unicode_data",)
 
     data = u(
-        "Alors vous imaginez ma surprise, au lever du jour, "
-        "quand une drôle de petite voix m’a réveillé. Elle "
-        "disait: « S’il vous plaît… dessine-moi un mouton! »"
+        "Alors vous imaginez ma 🐍 surprise, au lever du jour, "
+        "quand une drôle de petite 🐍 voix m’a réveillé. Elle "
+        "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »"
     )
 
+    @property
+    def supports_whereclause(self):
+        return config.requirements.expressions_against_unbounded_text.enabled
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -122,6 +138,11 @@ class _UnicodeFixture(_LiteralRoundTripFixture):
     def test_literal(self):
         self._literal_round_trip(self.datatype, [self.data], [self.data])
 
+    def test_literal_non_ascii(self):
+        self._literal_round_trip(
+            self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+        )
+
 
 class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
     __requires__ = ("unicode_data",)
@@ -149,6 +170,10 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     __requires__ = ("text_type",)
     __backend__ = True
 
+    @property
+    def supports_whereclause(self):
+        return config.requirements.expressions_against_unbounded_text.enabled
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -177,6 +202,11 @@ class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_literal(self):
         self._literal_round_trip(Text, ["some text"], ["some text"])
 
+    def test_literal_non_ascii(self):
+        self._literal_round_trip(
+            Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+        )
+
     def test_literal_quoting(self):
         data = """some 'text' hey "hi there" that's text"""
         self._literal_round_trip(Text, [data], [data])
@@ -202,8 +232,15 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
         foo.drop(config.db)
 
     def test_literal(self):
+        # note that in Python 3, this invokes the Unicode
+        # datatype for the literal part because all strings are unicode
         self._literal_round_trip(String(40), ["some text"], ["some text"])
 
+    def test_literal_non_ascii(self):
+        self._literal_round_trip(
+            String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+        )
+
     def test_literal_quoting(self):
         data = """some 'text' hey "hi there" that's text"""
         self._literal_round_trip(String(40), [data], [data])
@@ -864,8 +901,8 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
                 {
                     "name": "r1",
                     "data": {
-                        util.u("réveillé"): util.u("réveillé"),
-                        "data": {"k1": util.u("drôle")},
+                        util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+                        "data": {"k1": util.u("drôl🐍e")},
                     },
                 },
             )
@@ -873,8 +910,8 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
             eq_(
                 conn.scalar(select([self.tables.data_table.c.data])),
                 {
-                    util.u("réveillé"): util.u("réveillé"),
-                    "data": {"k1": util.u("drôle")},
+                    util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+                    "data": {"k1": util.u("drôl🐍e")},
                 },
             )
 
index ec6da6057f9c964e50b3ed9a81cdc4fc5380c5db..f0c0d4d5af30568c54e77749652f98c51cca8dc9 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -66,8 +66,8 @@ postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
 pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
 postgresql_psycopg2cffi=postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test
 
-mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8
-pymysql=mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8
+mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+pymysql=mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 
 mssql=mssql+pyodbc://scott:tiger@ms_2008
 mssql_pymssql=mssql+pymssql://scott:tiger@ms_2008
index 356057af16037f3c45a5d32bbadbb611c08bdf13..54dd6876a4edd0daf92c2ea0aaf49290fec6987f 100644 (file)
@@ -7,6 +7,7 @@ import os
 import sqlalchemy as sa
 from sqlalchemy import Boolean
 from sqlalchemy import Column
+from sqlalchemy import column
 from sqlalchemy import Date
 from sqlalchemy import DateTime
 from sqlalchemy import DefaultClause
@@ -14,6 +15,7 @@ from sqlalchemy import Float
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import LargeBinary
+from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import Numeric
 from sqlalchemy import PickleType
@@ -27,7 +29,9 @@ from sqlalchemy import Text
 from sqlalchemy import text
 from sqlalchemy import Time
 from sqlalchemy import types
+from sqlalchemy import Unicode
 from sqlalchemy import UnicodeText
+from sqlalchemy import util
 from sqlalchemy.databases import mssql
 from sqlalchemy.dialects.mssql import ROWVERSION
 from sqlalchemy.dialects.mssql import TIMESTAMP
@@ -37,6 +41,7 @@ from sqlalchemy.dialects.mssql.base import MS_2008_VERSION
 from sqlalchemy.dialects.mssql.base import TIME
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import ComparesTables
 from sqlalchemy.testing import emits_warning_on
@@ -922,6 +927,49 @@ class TypeRoundTripTest(
                 engine.execute(tbl.delete())
 
 
+class StringTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = mssql.dialect()
+
+    def test_unicode_literal_binds(self):
+        self.assert_compile(
+            column("x", Unicode()) == "foo", "x = N'foo'", literal_binds=True
+        )
+
+    def test_unicode_text_literal_binds(self):
+        self.assert_compile(
+            column("x", UnicodeText()) == "foo",
+            "x = N'foo'",
+            literal_binds=True,
+        )
+
+    def test_string_text_literal_binds(self):
+        self.assert_compile(
+            column("x", String()) == "foo", "x = 'foo'", literal_binds=True
+        )
+
+    def test_string_text_literal_binds_explicit_unicode_right(self):
+        self.assert_compile(
+            column("x", String()) == util.u("foo"),
+            "x = 'foo'",
+            literal_binds=True,
+        )
+
+    def test_string_text_explicit_literal_binds(self):
+        # the literal expression here coerces the right side to
+        # Unicode on Python 3 for plain string, test with unicode
+        # string just to confirm literal is doing this
+        self.assert_compile(
+            column("x", String()) == literal(util.u("foo")),
+            "x = N'foo'",
+            literal_binds=True,
+        )
+
+    def test_text_text_literal_binds(self):
+        self.assert_compile(
+            column("x", Text()) == "foo", "x = 'foo'", literal_binds=True
+        )
+
+
 class BinaryTest(fixtures.TestBase):
     __only_on__ = "mssql"
     __requires__ = ("non_broken_binary",)
index c70169acfaacede0bec9b56f1423d24e6cf41048..c265bb3c9d783bfa72ed9787545c542540e5435c 100644 (file)
@@ -662,6 +662,16 @@ class DefaultRequirements(SuiteRequirements):
 
         return exclusions.open()
 
+    @property
+    def expressions_against_unbounded_text(self):
+        """target database supports use of an unbounded textual field in a
+        WHERE clause."""
+
+        return fails_if(
+            ["oracle"],
+            "ORA-00932: inconsistent datatypes: expected - got CLOB",
+        )
+
     @property
     def unicode_data(self):
         """target drive must support unicode data stored in columns."""
@@ -1173,9 +1183,9 @@ class DefaultRequirements(SuiteRequirements):
         lookup = {
             # will raise without quoting
             "postgresql": "POSIX",
-            # note MySQL databases need to be created w/ utf8mb3 charset
+            # note MySQL databases need to be created w/ utf8mb4 charset
             # for the test suite
-            "mysql": "utf8mb3_bin",
+            "mysql": "utf8mb4_bin",
             "sqlite": "NOCASE",
             # will raise *with* quoting
             "mssql": "Latin1_General_CI_AS",