]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] Added "collation" parameter to all
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Oct 2012 23:34:29 +0000 (19:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Oct 2012 23:34:29 +0000 (19:34 -0400)
    String types.  When present, renders as
    COLLATE <collation>.  This to support the
    COLLATE keyword now supported by several
    databases including MySQL, SQLite, and Postgresql.
    [ticket:2276]

  - [change] The Text() type renders the length
    given to it, if a length was specified.

CHANGES
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/types.py
test/dialect/test_mysql.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 4dfee65a91ff5d5e3d18d618d0a776e69084557e..af4d1127e4173e94c1ec76e4c28ff80609ab2093 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -519,6 +519,16 @@ underneath "0.7.xx".
     also customizable via the "precedence" argument
     on the ``op()`` method.  [ticket:2537]
 
+  - [feature] Added "collation" parameter to all
+    String types.  When present, renders as
+    COLLATE <collation>.  This to support the
+    COLLATE keyword now supported by several
+    databases including MySQL, SQLite, and Postgresql.
+    [ticket:2276]
+
+  - [change] The Text() type renders the length
+    given to it, if a length was specified.
+
   - [feature] Custom unary operators can now be
     used by combining operators.custom_op() with
     UnaryExpression().
index 1ba567682c6c5b1ed100ac5364e8cb3ce6ca5982..c69ed24e88aaf972aeff2ac6660b2c61ddc42ff9 100644 (file)
@@ -389,8 +389,10 @@ class _StringType(sqltypes.String):
                  ascii=False, binary=False,
                  national=False, **kw):
         self.charset = charset
+
         # allow collate= or collation=
-        self.collation = kw.pop('collate', collation)
+        kw.setdefault('collation', kw.pop('collate', collation))
+
         self.ascii = ascii
         # We have to munge the 'unicode' param strictly as a dict
         # otherwise 2to3 will turn it into str.
@@ -402,19 +404,6 @@ class _StringType(sqltypes.String):
         self.national = national
         super(_StringType, self).__init__(**kw)
 
-    def __repr__(self):
-        attributes = inspect.getargspec(self.__init__)[0][1:]
-        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
-
-        params = {}
-        for attr in attributes:
-            val = getattr(self, attr)
-            if val is not None and val is not False:
-                params[attr] = val
-
-        return "%s(%s)" % (self.__class__.__name__,
-                           ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-
 
 class NUMERIC(_NumericType, sqltypes.NUMERIC):
     """MySQL NUMERIC type."""
@@ -1489,7 +1478,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         opts = dict(
             (
-                k[len(self.dialect.name)+1:].upper(),
+                k[len(self.dialect.name) + 1:].upper(),
                 v
             )
             for k, v in table.kwargs.items()
@@ -1772,7 +1761,8 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
 
     def visit_CHAR(self, type_):
         if type_.length:
-            return self._extend_string(type_, {}, "CHAR(%(length)s)" % {'length' : type_.length})
+            return self._extend_string(type_, {}, "CHAR(%(length)s)" %
+                                        {'length': type_.length})
         else:
             return self._extend_string(type_, {}, "CHAR")
 
@@ -1780,7 +1770,8 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
         # of "NVARCHAR".
         if type_.length:
-            return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length})
+            return self._extend_string(type_, {'national': True},
+                        "VARCHAR(%(length)s)" % {'length': type_.length})
         else:
             raise exc.CompileError(
                     "NVARCHAR requires a length on dialect %s" %
@@ -1789,9 +1780,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
     def visit_NCHAR(self, type_):
         # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
         if type_.length:
-            return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length})
+            return self._extend_string(type_, {'national': True},
+                        "CHAR(%(length)s)" % {'length': type_.length})
         else:
-            return self._extend_string(type_, {'national':True}, "CHAR")
+            return self._extend_string(type_, {'national': True}, "CHAR")
 
     def visit_VARBINARY(self, type_):
         return "VARBINARY(%d)" % type_.length
index cc41e61825ca0cf261780275c0e1710010cdb2ac..f705a216e8b1deff5d5e7e8945c1de509530aeca 100644 (file)
@@ -2055,11 +2055,6 @@ class DDLCompiler(engine.Compiled):
 
 
 class GenericTypeCompiler(engine.TypeCompiler):
-    def visit_CHAR(self, type_):
-        return "CHAR" + (type_.length and "(%d)" % type_.length or "")
-
-    def visit_NCHAR(self, type_):
-        return "NCHAR" + (type_.length and "(%d)" % type_.length or "")
 
     def visit_FLOAT(self, type_):
         return "FLOAT"
@@ -2108,11 +2103,29 @@ class GenericTypeCompiler(engine.TypeCompiler):
     def visit_NCLOB(self, type_):
         return "NCLOB"
 
+    def _render_string_type(self, type_, name):
+
+        text = name
+        if type_.length:
+            text += "(%d)" % type_.length
+        if type_.collation:
+            text += ' COLLATE "%s"' % type_.collation
+        return text
+
+    def visit_CHAR(self, type_):
+        return self._render_string_type(type_, "CHAR")
+
+    def visit_NCHAR(self, type_):
+        return self._render_string_type(type_, "NCHAR")
+
     def visit_VARCHAR(self, type_):
-        return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
+        return self._render_string_type(type_, "VARCHAR")
 
     def visit_NVARCHAR(self, type_):
-        return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "")
+        return self._render_string_type(type_, "NVARCHAR")
+
+    def visit_TEXT(self, type_):
+        return self._render_string_type(type_, "TEXT")
 
     def visit_BLOB(self, type_):
         return "BLOB"
@@ -2126,8 +2139,6 @@ class GenericTypeCompiler(engine.TypeCompiler):
     def visit_BOOLEAN(self, type_):
         return "BOOLEAN"
 
-    def visit_TEXT(self, type_):
-        return "TEXT"
 
     def visit_large_binary(self, type_):
         return self.visit_BLOB(type_)
index 560bc9c97a5397f7f4dbafde3489500ca871672b..bdd619bad403c93cbd82cf36f52241497de7c1d1 100644 (file)
@@ -162,3 +162,9 @@ class SuiteRequirements(Requirements):
     @property
     def index_reflection(self):
         return exclusions.open()
+
+    @property
+    def unbounded_varchar(self):
+        """Target database must support VARCHAR with no length"""
+
+        return exclusions.open()
index eeb19496b0cd5c84067defc0a84ba8a7eae0981d..71bd39ba6e9fdfba104f3af2534f661b85a978f2 100644 (file)
@@ -971,7 +971,8 @@ class String(Concatenable, TypeEngine):
 
     __visit_name__ = 'string'
 
-    def __init__(self, length=None, convert_unicode=False,
+    def __init__(self, length=None, collation=None,
+                        convert_unicode=False,
                         assert_unicode=None, unicode_error=None,
                         _warn_on_bytestring=False
                         ):
@@ -979,13 +980,25 @@ class String(Concatenable, TypeEngine):
         Create a string-holding type.
 
         :param length: optional, a length for the column for use in
-          DDL statements.  May be safely omitted if no ``CREATE
+          DDL and CAST expressions.  May be safely omitted if no ``CREATE
           TABLE`` will be issued.  Certain databases may require a
           ``length`` for use in DDL, and will raise an exception when
           the ``CREATE TABLE`` DDL is issued if a ``VARCHAR``
           with no length is included.  Whether the value is
           interpreted as bytes or characters is database specific.
 
+        :param collation: Optional, a column-level collation for
+          use in DDL and CAST expressions.  Renders using the
+          COLLATE keyword supported by SQLite, MySQL, and Postgresql.
+          E.g.::
+
+            >>> from sqlalchemy import cast, select, String
+            >>> print select([cast('some string', String(collation='utf8'))])
+            SELECT CAST(:param_1 AS VARCHAR COLLATE utf8) AS anon_1
+
+          .. versionadded:: 0.8 Added support for COLLATE to all
+             string types.
+
         :param convert_unicode: When set to ``True``, the
           :class:`.String` type will assume that
           input is to be passed as Python ``unicode`` objects,
@@ -1046,6 +1059,7 @@ class String(Concatenable, TypeEngine):
                                  '*not* apply to DBAPIs that coerce '
                                  'Unicode natively.')
         self.length = length
+        self.collation = collation
         self.convert_unicode = convert_unicode
         self.unicode_error = unicode_error
         self._warn_on_bytestring = _warn_on_bytestring
index de1df3846106d2af3ba48adff0bed485c73654ba..d5ce6e923ac677d5adb662cb5cc8428265b0b14b 100644 (file)
@@ -306,7 +306,11 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
              'LONGTEXT ASCII'),
 
             (mysql.ENUM, ["foo", "bar"], {'unicode':True},
-             '''ENUM('foo','bar') UNICODE''')
+             '''ENUM('foo','bar') UNICODE'''),
+
+            (String, [20], {"collation":"utf8"}, 'VARCHAR(20) COLLATE utf8')
+
+
            ]
 
         for type_, args, kw, res in columns:
index 81b572989a2bdb1248233701d2ce431a026d906c..fae28a0bde9bed0d8920d77b7726cb8785d04aec 100644 (file)
@@ -244,14 +244,13 @@ class PickleMetadataTest(fixtures.TestBase):
                 Column('Lar', LargeBinary()),
                 Column('Pic', PickleType()),
                 Column('Int', Interval()),
-                Column('Enu', Enum('x','y','z', name="somename")),
+                Column('Enu', Enum('x', 'y', 'z', name="somename")),
             ]
             for column_type in column_types:
-                #print column_type
                 meta = MetaData()
                 Table('foo', meta, column_type)
-                ct = loads(dumps(column_type))
-                mt = loads(dumps(meta))
+                loads(dumps(column_type))
+                loads(dumps(meta))
 
 
 class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
@@ -305,7 +304,7 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
                 raw_dialect_impl = raw_impl.dialect_impl(dialect_)
                 dec_dialect_impl = dec_type.dialect_impl(dialect_)
                 eq_(dec_dialect_impl.__class__, MyType)
-                eq_(raw_dialect_impl.__class__ , dec_dialect_impl.impl.__class__)
+                eq_(raw_dialect_impl.__class__, dec_dialect_impl.impl.__class__)
 
                 self.assert_compile(
                     MyType(**kw),
@@ -1394,24 +1393,51 @@ class ExpressionTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
         assert test_table.c.data.distinct().type == test_table.c.data.type
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
-    def test_default_compile(self):
-        """test that the base dialect of the type object is used
-        for default compilation.
 
-        """
+    @testing.requires.unbounded_varchar
+    def test_string_plain(self):
+        self.assert_compile(String(), "VARCHAR")
+
+    def test_string_length(self):
+        self.assert_compile(String(50), "VARCHAR(50)")
+
+    def test_string_collation(self):
+        self.assert_compile(String(50, collation="FOO"),
+                'VARCHAR(50) COLLATE "FOO"')
+
+    def test_char_plain(self):
+        self.assert_compile(CHAR(), "CHAR")
+
+    def test_char_length(self):
+        self.assert_compile(CHAR(50), "CHAR(50)")
+
+    def test_char_collation(self):
+        self.assert_compile(CHAR(50, collation="FOO"),
+                'CHAR(50) COLLATE "FOO"')
+
+    def test_text_plain(self):
+        self.assert_compile(Text(), "TEXT")
+
+    def test_text_length(self):
+        self.assert_compile(Text(50), "TEXT(50)")
+
+    def test_text_collation(self):
+        self.assert_compile(Text(collation="FOO"),
+                'TEXT COLLATE "FOO"')
+
+    def test_default_compile_pg_inet(self):
+        self.assert_compile(dialects.postgresql.INET(), "INET",
+                allow_dialect_select=True)
+
+    def test_default_compile_pg_float(self):
+        self.assert_compile(dialects.postgresql.FLOAT(), "FLOAT",
+                allow_dialect_select=True)
+
+    def test_default_compile_mysql_integer(self):
+        self.assert_compile(
+                dialects.mysql.INTEGER(display_width=5), "INTEGER(5)",
+                allow_dialect_select=True)
 
-        for type_, expected in (
-            (String(), "VARCHAR"),
-            (Integer(), "INTEGER"),
-            (dialects.postgresql.INET(), "INET"),
-            (dialects.postgresql.FLOAT(), "FLOAT"),
-            (dialects.mysql.REAL(precision=8, scale=2), "REAL(8, 2)"),
-            (dialects.postgresql.REAL(), "REAL"),
-            (INTEGER(), "INTEGER"),
-            (dialects.mysql.INTEGER(display_width=5), "INTEGER(5)")
-        ):
-            self.assert_compile(type_, expected,
-                                allow_dialect_select=True)
 
 class DateTest(fixtures.TestBase, AssertsExecutionResults):
     @classmethod