]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make column-level collation quoting dialect-specific
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Jan 2018 03:17:59 +0000 (22:17 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 12 Jan 2018 18:01:26 +0000 (13:01 -0500)
Fixed regression in 1.2 where newly repaired quoting
of collation names in :ticket:`3785` breaks SQL Server,
which explicitly does not understand a quoted collation
name.   Whether or not mixed-case collation names are
quoted or not is now deferred down to a dialect-level
decision so that each dialect can prepare these identifiers
directly.

Change-Id: Iaf0a8123d9bf4711219e320896bb28c5d2649304
Fixes: #4154
doc/build/changelog/unreleased_12/4154.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_select.py
test/dialect/mssql/test_compiler.py
test/requirements.py
test/sql/test_compiler.py
test/sql/test_quote.py

diff --git a/doc/build/changelog/unreleased_12/4154.rst b/doc/build/changelog/unreleased_12/4154.rst
new file mode 100644 (file)
index 0000000..523e330
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 4154
+
+    Fixed regression in 1.2 where newly repaired quoting
+    of collation names in :ticket:`3785` breaks SQL Server,
+    which explicitly does not understand a quoted collation
+    name.   Whether or not mixed-case collation names are
+    quoted or not is now deferred down to a dialect-level
+    decision so that each dialect can prepare these identifiers
+    directly.
index 9f4e7a9c4151e934b720a00f59f774cd3a406b52..e72ca06b0a4223eaa5f642c7f5ecdcba44a0b603 100644 (file)
@@ -1713,13 +1713,13 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = RESERVED_WORDS
 
     def __init__(self, dialect):
-        super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[',
-                                                   final_quote=']')
+        super(MSIdentifierPreparer, self).__init__(
+            dialect, initial_quote='[',
+            final_quote=']', quote_case_sensitive_collations=False)
 
     def _escape_identifier(self, value):
         return value
 
-
     def quote_schema(self, schema, force=None):
         """Prepare a quoted table and schema name."""
 
index cb058affa28a4c5bb0dbd1f586d0e010ccdda1ef..9411329a155a1fa01f8da0b4f44b84a42ddaf12f 100644 (file)
@@ -733,6 +733,9 @@ class SQLCompiler(Compiled):
                 self.preparer.quote(tablename) + \
                 "." + name
 
+    def visit_collation(self, element, **kw):
+        return self.preparer.format_collation(element.collation)
+
     def visit_fromclause(self, fromclause, **kwargs):
         return fromclause.name
 
@@ -2961,7 +2964,8 @@ class IdentifierPreparer(object):
     schema_for_object = schema._schema_getter(None)
 
     def __init__(self, dialect, initial_quote='"',
-                 final_quote=None, escape_quote='"', omit_schema=False):
+                 final_quote=None, escape_quote='"',
+                 quote_case_sensitive_collations=True, omit_schema=False):
         """Construct a new ``IdentifierPreparer`` object.
 
         initial_quote
@@ -2982,6 +2986,7 @@ class IdentifierPreparer(object):
         self.escape_quote = escape_quote
         self.escape_to_quote = self.escape_quote * 2
         self.omit_schema = omit_schema
+        self.quote_case_sensitive_collations = quote_case_sensitive_collations
         self._strings = {}
         self._double_percents = self.dialect.paramstyle in ('format', 'pyformat')
 
@@ -3064,6 +3069,12 @@ class IdentifierPreparer(object):
         else:
             return ident
 
+    def format_collation(self, collation_name):
+        if self.quote_case_sensitive_collations:
+            return self.quote(collation_name)
+        else:
+            return collation_name
+
     def format_sequence(self, sequence, use_schema=True):
         name = self.quote(sequence.name)
 
index 2cc1d9c42382939a17bcc8beb457b4a1c78c5b6d..fd2c9c0bd123a87cc9c2cc808112fb9dae678c7a 100644 (file)
@@ -52,7 +52,7 @@ def collate(expression, collation):
     expr = _literal_as_binds(expression)
     return BinaryExpression(
         expr,
-        ColumnClause(collation),
+        CollationClause(collation),
         operators.collate, type_=expr.type)
 
 
@@ -3873,6 +3873,13 @@ class ColumnClause(Immutable, ColumnElement):
         return c
 
 
+class CollationClause(ColumnElement):
+    __visit_name__ = "collation"
+
+    def __init__(self, collation):
+        self.collation = collation
+
+
 class _IdentifiedClause(Executable, ClauseElement):
 
     __visit_name__ = 'identified'
index b89d149d698d702893c937834d7f8422092e55c0..cc9e074efafd66ed01f9b43b1b78e68eaf089aad 100644 (file)
@@ -756,6 +756,20 @@ class SuiteRequirements(Requirements):
         """
         return exclusions.closed()
 
+    @property
+    def order_by_collation(self):
+        def check(config):
+            try:
+                self.get_order_by_collation(config)
+                return False
+            except NotImplementedError:
+                return True
+
+        return exclusions.skip_if(check)
+
+    def get_order_by_collation(self, config):
+        raise NotImplementedError()
+
     @property
     def unicode_connections(self):
         """Target driver must support non-ASCII characters being passed at
index df638c140d240aa16f9ed77ab0c8c6c117e1595f..d9755c8f972c6b38629dc6234d5aebd56d61c7e7 100644 (file)
@@ -9,6 +9,46 @@ from sqlalchemy import literal_column
 from ..schema import Table, Column
 
 
+class CollateTest(fixtures.TablesTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("some_table", metadata,
+              Column('id', Integer, primary_key=True),
+              Column('data', String(100))
+              )
+
+    @classmethod
+    def insert_data(cls):
+        config.db.execute(
+            cls.tables.some_table.insert(),
+            [
+                {"id": 1, "data": "collate data1"},
+                {"id": 2, "data": "collate data2"},
+            ]
+        )
+
+    def _assert_result(self, select, result):
+        eq_(
+            config.db.execute(select).fetchall(),
+            result
+        )
+
+    @testing.requires.order_by_collation
+    def test_collate_order_by(self):
+        collation = testing.requires.get_order_by_collation(testing.config)
+
+        self._assert_result(
+            select([self.tables.some_table]).
+            order_by(self.tables.some_table.c.data.collate(collation).asc()),
+            [
+                (1, "collate data1"),
+                (2, "collate data2"),
+            ]
+        )
+
+
 class OrderByLabelTest(fixtures.TablesTest):
     """Test the dialect sends appropriate ORDER BY expressions when
     labels are used.
index d62753b9d74c4b423436abcba598d8d9bfcf8e75..e9f9afef5b647d2fa8ff0b398e93746f69c5ede2 100644 (file)
@@ -45,6 +45,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             'SELECT test_schema.sometable.somecolumn '
             'FROM test_schema.sometable WITH (NOLOCK)')
 
+    def test_select_w_order_by_collate(self):
+        m = MetaData()
+        t = Table('sometable', m, Column('somecolumn', String))
+
+        self.assert_compile(
+            select([t]).
+            order_by(
+                t.c.somecolumn.collate("Latin1_General_CS_AS_KS_WS_CI").asc()),
+            "SELECT sometable.somecolumn FROM sometable "
+            "ORDER BY sometable.somecolumn COLLATE "
+            "Latin1_General_CS_AS_KS_WS_CI ASC"
+
+        )
+
     def test_join_with_hint(self):
         t1 = table('t1',
                    column('a', Integer),
index 3cbc5aaada675efa73e7a0d29d671b04116c7934..4be91b938d6c43e7faed2f1bc58a8f8d8c80abaa 100644 (file)
@@ -946,6 +946,23 @@ class DefaultRequirements(SuiteRequirements):
             ('mssql', None, None, 'only simple labels allowed')
         ])
 
+    def get_order_by_collation(self, config):
+        lookup = {
+
+            # will raise without quoting
+            "postgresql": "POSIX",
+
+            "mysql": "latin1_general_ci",
+            "sqlite": "NOCASE",
+
+            # will raise *with* quoting
+            "mssql": "Latin1_General_CI_AS"
+        }
+        try:
+            return lookup[config.db.name]
+        except KeyError:
+            raise NotImplementedError()
+
     @property
     def skip_mysql_on_windows(self):
         """Catchall for a large variety of MySQL on Windows failures"""
index 988230ac5c66d8ef3dfcbf7c1216bbc2976088fe..25eb2b24b6241b428ea64d73aa8f012d2b165536 100644 (file)
@@ -1450,6 +1450,25 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 []).compile,
             dialect=empty_in_dialect)
 
+    def test_collate(self):
+        # columns clause
+        self.assert_compile(
+            select([column('x').collate('bar')]),
+            "SELECT x COLLATE bar AS anon_1"
+        )
+
+        # WHERE clause
+        self.assert_compile(
+            select([column('x')]).where(column('x').collate('bar') == 'foo'),
+            "SELECT x WHERE (x COLLATE bar) = :param_1"
+        )
+
+        # ORDER BY clause
+        self.assert_compile(
+            select([column('x')]).order_by(column('x').collate('bar')),
+            "SELECT x ORDER BY x COLLATE bar"
+        )
+
     def test_literal(self):
 
         self.assert_compile(select([literal('foo')]),
index 477fca7836998f6ecaede4444a4409b28a68a2ba..a51e14244eeade143d1ce538e9011fdcc9fc997a 100644 (file)
@@ -1,6 +1,6 @@
 from sqlalchemy import MetaData, Table, Column, Integer, select, \
     ForeignKey, Index, CheckConstraint, inspect, column
-from sqlalchemy import sql, schema
+from sqlalchemy import sql, schema, types as sqltypes
 from sqlalchemy.sql import compiler
 from sqlalchemy.testing import fixtures, AssertsCompiledSQL, eq_
 from sqlalchemy import testing
@@ -462,7 +462,8 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             column('foo').collate('fr_FR'),
-            'foo COLLATE "fr_FR"'
+            'foo COLLATE "fr_FR"',
+            dialect="postgresql"
         )
 
         self.assert_compile(
@@ -471,6 +472,12 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect="mysql"
         )
 
+        self.assert_compile(
+            column('foo').collate('SQL_Latin1_General_CP1_CI_AS'),
+            'foo COLLATE SQL_Latin1_General_CP1_CI_AS',
+            dialect="mssql"
+        )
+
     def test_join(self):
         # Lower case names, should not quote
         metadata = MetaData()