From: Mike Bayer Date: Wed, 10 Jan 2018 03:17:59 +0000 (-0500) Subject: Make column-level collation quoting dialect-specific X-Git-Tag: rel_1_2_1~2^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7402987fd218c42ed2a909a5031186d2b702bb88;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Make column-level collation quoting dialect-specific Fixed regression in 1.2 where newly repaired quoting of collation names in :ticket:`3785` breaks SQL Server, which explicitly does not understand a quoted collation name. Whether or not mixed-case collation names are quoted or not is now deferred down to a dialect-level decision so that each dialect can prepare these identifiers directly. Change-Id: Iaf0a8123d9bf4711219e320896bb28c5d2649304 Fixes: #4154 --- diff --git a/doc/build/changelog/unreleased_12/4154.rst b/doc/build/changelog/unreleased_12/4154.rst new file mode 100644 index 0000000000..523e330374 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4154.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mssql + :tickets: 4154 + + Fixed regression in 1.2 where newly repaired quoting + of collation names in :ticket:`3785` breaks SQL Server, + which explicitly does not understand a quoted collation + name. Whether or not mixed-case collation names are + quoted or not is now deferred down to a dialect-level + decision so that each dialect can prepare these identifiers + directly. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9f4e7a9c41..e72ca06b0a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1713,13 +1713,13 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS def __init__(self, dialect): - super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', - final_quote=']') + super(MSIdentifierPreparer, self).__init__( + dialect, initial_quote='[', + final_quote=']', quote_case_sensitive_collations=False) def _escape_identifier(self, value): return value - def quote_schema(self, schema, force=None): """Prepare a quoted table and schema name.""" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cb058affa2..9411329a15 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -733,6 +733,9 @@ class SQLCompiler(Compiled): self.preparer.quote(tablename) + \ "." + name + def visit_collation(self, element, **kw): + return self.preparer.format_collation(element.collation) + def visit_fromclause(self, fromclause, **kwargs): return fromclause.name @@ -2961,7 +2964,8 @@ class IdentifierPreparer(object): schema_for_object = schema._schema_getter(None) def __init__(self, dialect, initial_quote='"', - final_quote=None, escape_quote='"', omit_schema=False): + final_quote=None, escape_quote='"', + quote_case_sensitive_collations=True, omit_schema=False): """Construct a new ``IdentifierPreparer`` object. initial_quote @@ -2982,6 +2986,7 @@ class IdentifierPreparer(object): self.escape_quote = escape_quote self.escape_to_quote = self.escape_quote * 2 self.omit_schema = omit_schema + self.quote_case_sensitive_collations = quote_case_sensitive_collations self._strings = {} self._double_percents = self.dialect.paramstyle in ('format', 'pyformat') @@ -3064,6 +3069,12 @@ class IdentifierPreparer(object): else: return ident + def format_collation(self, collation_name): + if self.quote_case_sensitive_collations: + return self.quote(collation_name) + else: + return collation_name + def format_sequence(self, sequence, use_schema=True): name = self.quote(sequence.name) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 2cc1d9c423..fd2c9c0bd1 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -52,7 +52,7 @@ def collate(expression, collation): expr = _literal_as_binds(expression) return BinaryExpression( expr, - ColumnClause(collation), + CollationClause(collation), operators.collate, type_=expr.type) @@ -3873,6 +3873,13 @@ class ColumnClause(Immutable, ColumnElement): return c +class CollationClause(ColumnElement): + __visit_name__ = "collation" + + def __init__(self, collation): + self.collation = collation + + class _IdentifiedClause(Executable, ClauseElement): __visit_name__ = 'identified' diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index b89d149d69..cc9e074efa 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -756,6 +756,20 @@ class SuiteRequirements(Requirements): """ return exclusions.closed() + @property + def order_by_collation(self): + def check(config): + try: + self.get_order_by_collation(config) + return False + except NotImplementedError: + return True + + return exclusions.skip_if(check) + + def get_order_by_collation(self, config): + raise NotImplementedError() + @property def unicode_connections(self): """Target driver must support non-ASCII characters being passed at diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index df638c140d..d9755c8f97 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -9,6 +9,46 @@ from sqlalchemy import literal_column from ..schema import Table, Column +class CollateTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table("some_table", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(100)) + ) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "collate data1"}, + {"id": 2, "data": "collate data2"}, + ] + ) + + def _assert_result(self, select, result): + eq_( + config.db.execute(select).fetchall(), + result + ) + + @testing.requires.order_by_collation + def test_collate_order_by(self): + collation = testing.requires.get_order_by_collation(testing.config) + + self._assert_result( + select([self.tables.some_table]). + order_by(self.tables.some_table.c.data.collate(collation).asc()), + [ + (1, "collate data1"), + (2, "collate data2"), + ] + ) + + class OrderByLabelTest(fixtures.TablesTest): """Test the dialect sends appropriate ORDER BY expressions when labels are used. diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index d62753b9d7..e9f9afef5b 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -45,6 +45,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): 'SELECT test_schema.sometable.somecolumn ' 'FROM test_schema.sometable WITH (NOLOCK)') + def test_select_w_order_by_collate(self): + m = MetaData() + t = Table('sometable', m, Column('somecolumn', String)) + + self.assert_compile( + select([t]). + order_by( + t.c.somecolumn.collate("Latin1_General_CS_AS_KS_WS_CI").asc()), + "SELECT sometable.somecolumn FROM sometable " + "ORDER BY sometable.somecolumn COLLATE " + "Latin1_General_CS_AS_KS_WS_CI ASC" + + ) + def test_join_with_hint(self): t1 = table('t1', column('a', Integer), diff --git a/test/requirements.py b/test/requirements.py index 3cbc5aaada..4be91b938d 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -946,6 +946,23 @@ class DefaultRequirements(SuiteRequirements): ('mssql', None, None, 'only simple labels allowed') ]) + def get_order_by_collation(self, config): + lookup = { + + # will raise without quoting + "postgresql": "POSIX", + + "mysql": "latin1_general_ci", + "sqlite": "NOCASE", + + # will raise *with* quoting + "mssql": "Latin1_General_CI_AS" + } + try: + return lookup[config.db.name] + except KeyError: + raise NotImplementedError() + @property def skip_mysql_on_windows(self): """Catchall for a large variety of MySQL on Windows failures""" diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 988230ac5c..25eb2b24b6 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1450,6 +1450,25 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): []).compile, dialect=empty_in_dialect) + def test_collate(self): + # columns clause + self.assert_compile( + select([column('x').collate('bar')]), + "SELECT x COLLATE bar AS anon_1" + ) + + # WHERE clause + self.assert_compile( + select([column('x')]).where(column('x').collate('bar') == 'foo'), + "SELECT x WHERE (x COLLATE bar) = :param_1" + ) + + # ORDER BY clause + self.assert_compile( + select([column('x')]).order_by(column('x').collate('bar')), + "SELECT x ORDER BY x COLLATE bar" + ) + def test_literal(self): self.assert_compile(select([literal('foo')]), diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 477fca7836..a51e14244e 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -1,6 +1,6 @@ from sqlalchemy import MetaData, Table, Column, Integer, select, \ ForeignKey, Index, CheckConstraint, inspect, column -from sqlalchemy import sql, schema +from sqlalchemy import sql, schema, types as sqltypes from sqlalchemy.sql import compiler from sqlalchemy.testing import fixtures, AssertsCompiledSQL, eq_ from sqlalchemy import testing @@ -462,7 +462,8 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( column('foo').collate('fr_FR'), - 'foo COLLATE "fr_FR"' + 'foo COLLATE "fr_FR"', + dialect="postgresql" ) self.assert_compile( @@ -471,6 +472,12 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): dialect="mysql" ) + self.assert_compile( + column('foo').collate('SQL_Latin1_General_CP1_CI_AS'), + 'foo COLLATE SQL_Latin1_General_CP1_CI_AS', + dialect="mssql" + ) + def test_join(self): # Lower case names, should not quote metadata = MetaData()