]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Custom dialects that implement :class:`.GenericTypeCompiler` can
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jan 2015 01:03:33 +0000 (20:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jan 2015 01:03:33 +0000 (20:03 -0500)
now be constructed such that the visit methods receive an indication
of the owning expression object, if any.  Any visit method that
accepts keyword arguments (e.g. ``**kw``) will in most cases
receive a keyword argument ``type_expression``, referring to the
expression object that the type is contained within.  For columns
in DDL, the dialect's compiler class may need to alter its
``get_column_specification()`` method to support this as well.
The ``UserDefinedType.get_col_spec()`` method will also receive
``type_expression`` if it provides ``**kw`` in its argument
signature.
fixes #3074

13 files changed:
doc/build/changelog/changelog_10.rst
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/sql/test_types.py

index 5d8bb7b6837f637f0d6884e9961de602d8e06a3e..089c9fafb356a3abd00133a4867316268eb289fa 100644 (file)
     series as well.  For changes that are specific to 1.0 with an emphasis
     on compatibility concerns, see :doc:`/changelog/migration_10`.
 
+    .. change::
+        :tags: enhancement, sql
+        :tickets: 3074
+
+        Custom dialects that implement :class:`.GenericTypeCompiler` can
+        now be constructed such that the visit methods receive an indication
+        of the owning expression object, if any.  Any visit method that
+        accepts keyword arguments (e.g. ``**kw``) will in most cases
+        receive a keyword argument ``type_expression``, referring to the
+        expression object that the type is contained within.  For columns
+        in DDL, the dialect's compiler class may need to alter its
+        ``get_column_specification()`` method to support this as well.
+        The ``UserDefinedType.get_col_spec()`` method will also receive
+        ``type_expression`` if it provides ``**kw`` in its argument
+        signature.
+
     .. change::
         :tags: bug, sql
         :tickets: 3288
index 36229a105f5312ae562d10959aa3bb59cb9c063c..74e8abfc26e4237bcdadf11effd2a180284b704b 100644 (file)
@@ -180,16 +180,16 @@ ischema_names = {
 # _FBDate, etc. as bind/result functionality is required)
 
 class FBTypeCompiler(compiler.GenericTypeCompiler):
-    def visit_boolean(self, type_):
-        return self.visit_SMALLINT(type_)
+    def visit_boolean(self, type_, **kw):
+        return self.visit_SMALLINT(type_, **kw)
 
-    def visit_datetime(self, type_):
-        return self.visit_TIMESTAMP(type_)
+    def visit_datetime(self, type_, **kw):
+        return self.visit_TIMESTAMP(type_, **kw)
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         return "BLOB SUB_TYPE 1"
 
-    def visit_BLOB(self, type_):
+    def visit_BLOB(self, type_, **kw):
         return "BLOB SUB_TYPE 0"
 
     def _extend_string(self, type_, basic):
@@ -199,16 +199,16 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return '%s CHARACTER SET %s' % (basic, charset)
 
-    def visit_CHAR(self, type_):
-        basic = super(FBTypeCompiler, self).visit_CHAR(type_)
+    def visit_CHAR(self, type_, **kw):
+        basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
         return self._extend_string(type_, basic)
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         if not type_.length:
             raise exc.CompileError(
                 "VARCHAR requires a length on dialect %s" %
                 self.dialect.name)
-        basic = super(FBTypeCompiler, self).visit_VARCHAR(type_)
+        basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
         return self._extend_string(type_, basic)
 
 
index 5d84975c064558d77e5d2c6185dd630907ec4ff3..92d7e4ab3105fbd19290a78cd06d4838297b1278 100644 (file)
@@ -694,7 +694,6 @@ ischema_names = {
 
 
 class MSTypeCompiler(compiler.GenericTypeCompiler):
-
     def _extend(self, spec, type_, length=None):
         """Extend a string-type declaration with standard SQL
         COLLATE annotations.
@@ -715,115 +714,115 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
         return ' '.join([c for c in (spec, collation)
                          if c is not None])
 
-    def visit_FLOAT(self, type_):
+    def visit_FLOAT(self, type_, **kw):
         precision = getattr(type_, 'precision', None)
         if precision is None:
             return "FLOAT"
         else:
             return "FLOAT(%(precision)s)" % {'precision': precision}
 
-    def visit_TINYINT(self, type_):
+    def visit_TINYINT(self, type_, **kw):
         return "TINYINT"
 
-    def visit_DATETIMEOFFSET(self, type_):
+    def visit_DATETIMEOFFSET(self, type_, **kw):
         if type_.precision:
             return "DATETIMEOFFSET(%s)" % type_.precision
         else:
             return "DATETIMEOFFSET"
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         precision = getattr(type_, 'precision', None)
         if precision:
             return "TIME(%s)" % precision
         else:
             return "TIME"
 
-    def visit_DATETIME2(self, type_):
+    def visit_DATETIME2(self, type_, **kw):
         precision = getattr(type_, 'precision', None)
         if precision:
             return "DATETIME2(%s)" % precision
         else:
             return "DATETIME2"
 
-    def visit_SMALLDATETIME(self, type_):
+    def visit_SMALLDATETIME(self, type_, **kw):
         return "SMALLDATETIME"
 
-    def visit_unicode(self, type_):
-        return self.visit_NVARCHAR(type_)
+    def visit_unicode(self, type_, **kw):
+        return self.visit_NVARCHAR(type_, **kw)
 
-    def visit_text(self, type_):
+    def visit_text(self, type_, **kw):
         if self.dialect.deprecate_large_types:
-            return self.visit_VARCHAR(type_)
+            return self.visit_VARCHAR(type_, **kw)
         else:
-            return self.visit_TEXT(type_)
+            return self.visit_TEXT(type_, **kw)
 
-    def visit_unicode_text(self, type_):
+    def visit_unicode_text(self, type_, **kw):
         if self.dialect.deprecate_large_types:
-            return self.visit_NVARCHAR(type_)
+            return self.visit_NVARCHAR(type_, **kw)
         else:
-            return self.visit_NTEXT(type_)
+            return self.visit_NTEXT(type_, **kw)
 
-    def visit_NTEXT(self, type_):
+    def visit_NTEXT(self, type_, **kw):
         return self._extend("NTEXT", type_)
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         return self._extend("TEXT", type_)
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         return self._extend("VARCHAR", type_, length=type_.length or 'max')
 
-    def visit_CHAR(self, type_):
+    def visit_CHAR(self, type_, **kw):
         return self._extend("CHAR", type_)
 
-    def visit_NCHAR(self, type_):
+    def visit_NCHAR(self, type_, **kw):
         return self._extend("NCHAR", type_)
 
-    def visit_NVARCHAR(self, type_):
+    def visit_NVARCHAR(self, type_, **kw):
         return self._extend("NVARCHAR", type_, length=type_.length or 'max')
 
-    def visit_date(self, type_):
+    def visit_date(self, type_, **kw):
         if self.dialect.server_version_info < MS_2008_VERSION:
-            return self.visit_DATETIME(type_)
+            return self.visit_DATETIME(type_, **kw)
         else:
-            return self.visit_DATE(type_)
+            return self.visit_DATE(type_, **kw)
 
-    def visit_time(self, type_):
+    def visit_time(self, type_, **kw):
         if self.dialect.server_version_info < MS_2008_VERSION:
-            return self.visit_DATETIME(type_)
+            return self.visit_DATETIME(type_, **kw)
         else:
-            return self.visit_TIME(type_)
+            return self.visit_TIME(type_, **kw)
 
-    def visit_large_binary(self, type_):
+    def visit_large_binary(self, type_, **kw):
         if self.dialect.deprecate_large_types:
-            return self.visit_VARBINARY(type_)
+            return self.visit_VARBINARY(type_, **kw)
         else:
-            return self.visit_IMAGE(type_)
+            return self.visit_IMAGE(type_, **kw)
 
-    def visit_IMAGE(self, type_):
+    def visit_IMAGE(self, type_, **kw):
         return "IMAGE"
 
-    def visit_VARBINARY(self, type_):
+    def visit_VARBINARY(self, type_, **kw):
         return self._extend(
             "VARBINARY",
             type_,
             length=type_.length or 'max')
 
-    def visit_boolean(self, type_):
+    def visit_boolean(self, type_, **kw):
         return self.visit_BIT(type_)
 
-    def visit_BIT(self, type_):
+    def visit_BIT(self, type_, **kw):
         return "BIT"
 
-    def visit_MONEY(self, type_):
+    def visit_MONEY(self, type_, **kw):
         return "MONEY"
 
-    def visit_SMALLMONEY(self, type_):
+    def visit_SMALLMONEY(self, type_, **kw):
         return 'SMALLMONEY'
 
-    def visit_UNIQUEIDENTIFIER(self, type_):
+    def visit_UNIQUEIDENTIFIER(self, type_, **kw):
         return "UNIQUEIDENTIFIER"
 
-    def visit_SQL_VARIANT(self, type_):
+    def visit_SQL_VARIANT(self, type_, **kw):
         return 'SQL_VARIANT'
 
 
@@ -1240,8 +1239,11 @@ class MSSQLStrictCompiler(MSSQLCompiler):
 class MSDDLCompiler(compiler.DDLCompiler):
 
     def get_column_specification(self, column, **kwargs):
-        colspec = (self.preparer.format_column(column) + " "
-                   + self.dialect.type_compiler.process(column.type))
+        colspec = (
+            self.preparer.format_column(column) + " "
+            + self.dialect.type_compiler.process(
+                column.type, type_expression=column)
+        )
 
         if column.nullable is not None:
             if not column.nullable or column.primary_key or \
index 9c3f23cb2f2090c2e859018804279102c17983ae..ca56a4d232026325e6d1d6e5aff7cefffeac220e 100644 (file)
@@ -1859,9 +1859,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kw):
         """Builds column DDL."""
 
-        colspec = [self.preparer.format_column(column),
-                   self.dialect.type_compiler.process(column.type)
-                   ]
+        colspec = [
+            self.preparer.format_column(column),
+            self.dialect.type_compiler.process(
+                column.type, type_expression=column)
+        ]
 
         default = self.get_column_default_string(column)
         if default is not None:
@@ -2059,7 +2061,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
     def _mysql_type(self, type_):
         return isinstance(type_, (_StringType, _NumericType))
 
-    def visit_NUMERIC(self, type_):
+    def visit_NUMERIC(self, type_, **kw):
         if type_.precision is None:
             return self._extend_numeric(type_, "NUMERIC")
         elif type_.scale is None:
@@ -2072,7 +2074,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                                         {'precision': type_.precision,
                                          'scale': type_.scale})
 
-    def visit_DECIMAL(self, type_):
+    def visit_DECIMAL(self, type_, **kw):
         if type_.precision is None:
             return self._extend_numeric(type_, "DECIMAL")
         elif type_.scale is None:
@@ -2085,7 +2087,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                                         {'precision': type_.precision,
                                          'scale': type_.scale})
 
-    def visit_DOUBLE(self, type_):
+    def visit_DOUBLE(self, type_, **kw):
         if type_.precision is not None and type_.scale is not None:
             return self._extend_numeric(type_,
                                         "DOUBLE(%(precision)s, %(scale)s)" %
@@ -2094,7 +2096,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, 'DOUBLE')
 
-    def visit_REAL(self, type_):
+    def visit_REAL(self, type_, **kw):
         if type_.precision is not None and type_.scale is not None:
             return self._extend_numeric(type_,
                                         "REAL(%(precision)s, %(scale)s)" %
@@ -2103,7 +2105,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, 'REAL')
 
-    def visit_FLOAT(self, type_):
+    def visit_FLOAT(self, type_, **kw):
         if self._mysql_type(type_) and \
                 type_.scale is not None and \
                 type_.precision is not None:
@@ -2115,7 +2117,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "FLOAT")
 
-    def visit_INTEGER(self, type_):
+    def visit_INTEGER(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "INTEGER(%(display_width)s)" %
@@ -2123,7 +2125,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "INTEGER")
 
-    def visit_BIGINT(self, type_):
+    def visit_BIGINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "BIGINT(%(display_width)s)" %
@@ -2131,7 +2133,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "BIGINT")
 
-    def visit_MEDIUMINT(self, type_):
+    def visit_MEDIUMINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(
                 type_, "MEDIUMINT(%(display_width)s)" %
@@ -2139,14 +2141,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "MEDIUMINT")
 
-    def visit_TINYINT(self, type_):
+    def visit_TINYINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(type_,
                                         "TINYINT(%s)" % type_.display_width)
         else:
             return self._extend_numeric(type_, "TINYINT")
 
-    def visit_SMALLINT(self, type_):
+    def visit_SMALLINT(self, type_, **kw):
         if self._mysql_type(type_) and type_.display_width is not None:
             return self._extend_numeric(type_,
                                         "SMALLINT(%(display_width)s)" %
@@ -2155,55 +2157,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_numeric(type_, "SMALLINT")
 
-    def visit_BIT(self, type_):
+    def visit_BIT(self, type_, **kw):
         if type_.length is not None:
             return "BIT(%s)" % type_.length
         else:
             return "BIT"
 
-    def visit_DATETIME(self, type_):
+    def visit_DATETIME(self, type_, **kw):
         if getattr(type_, 'fsp', None):
             return "DATETIME(%d)" % type_.fsp
         else:
             return "DATETIME"
 
-    def visit_DATE(self, type_):
+    def visit_DATE(self, type_, **kw):
         return "DATE"
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         if getattr(type_, 'fsp', None):
             return "TIME(%d)" % type_.fsp
         else:
             return "TIME"
 
-    def visit_TIMESTAMP(self, type_):
+    def visit_TIMESTAMP(self, type_, **kw):
         if getattr(type_, 'fsp', None):
             return "TIMESTAMP(%d)" % type_.fsp
         else:
             return "TIMESTAMP"
 
-    def visit_YEAR(self, type_):
+    def visit_YEAR(self, type_, **kw):
         if type_.display_width is None:
             return "YEAR"
         else:
             return "YEAR(%s)" % type_.display_width
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         if type_.length:
             return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
         else:
             return self._extend_string(type_, {}, "TEXT")
 
-    def visit_TINYTEXT(self, type_):
+    def visit_TINYTEXT(self, type_, **kw):
         return self._extend_string(type_, {}, "TINYTEXT")
 
-    def visit_MEDIUMTEXT(self, type_):
+    def visit_MEDIUMTEXT(self, type_, **kw):
         return self._extend_string(type_, {}, "MEDIUMTEXT")
 
-    def visit_LONGTEXT(self, type_):
+    def visit_LONGTEXT(self, type_, **kw):
         return self._extend_string(type_, {}, "LONGTEXT")
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         if type_.length:
             return self._extend_string(
                 type_, {}, "VARCHAR(%d)" % type_.length)
@@ -2212,14 +2214,14 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 "VARCHAR requires a length on dialect %s" %
                 self.dialect.name)
 
-    def visit_CHAR(self, type_):
+    def visit_CHAR(self, type_, **kw):
         if type_.length:
             return self._extend_string(type_, {}, "CHAR(%(length)s)" %
                                        {'length': type_.length})
         else:
             return self._extend_string(type_, {}, "CHAR")
 
-    def visit_NVARCHAR(self, type_):
+    def visit_NVARCHAR(self, type_, **kw):
         # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
         # of "NVARCHAR".
         if type_.length:
@@ -2231,7 +2233,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
                 "NVARCHAR requires a length on dialect %s" %
                 self.dialect.name)
 
-    def visit_NCHAR(self, type_):
+    def visit_NCHAR(self, type_, **kw):
         # We'll actually generate the equiv.
         # "NATIONAL CHAR" instead of "NCHAR".
         if type_.length:
@@ -2241,31 +2243,31 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self._extend_string(type_, {'national': True}, "CHAR")
 
-    def visit_VARBINARY(self, type_):
+    def visit_VARBINARY(self, type_, **kw):
         return "VARBINARY(%d)" % type_.length
 
-    def visit_large_binary(self, type_):
+    def visit_large_binary(self, type_, **kw):
         return self.visit_BLOB(type_)
 
-    def visit_enum(self, type_):
+    def visit_enum(self, type_, **kw):
         if not type_.native_enum:
             return super(MySQLTypeCompiler, self).visit_enum(type_)
         else:
             return self._visit_enumerated_values("ENUM", type_, type_.enums)
 
-    def visit_BLOB(self, type_):
+    def visit_BLOB(self, type_, **kw):
         if type_.length:
             return "BLOB(%d)" % type_.length
         else:
             return "BLOB"
 
-    def visit_TINYBLOB(self, type_):
+    def visit_TINYBLOB(self, type_, **kw):
         return "TINYBLOB"
 
-    def visit_MEDIUMBLOB(self, type_):
+    def visit_MEDIUMBLOB(self, type_, **kw):
         return "MEDIUMBLOB"
 
-    def visit_LONGBLOB(self, type_):
+    def visit_LONGBLOB(self, type_, **kw):
         return "LONGBLOB"
 
     def _visit_enumerated_values(self, name, type_, enumerated_values):
@@ -2276,15 +2278,15 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
             name, ",".join(quoted_enums))
         )
 
-    def visit_ENUM(self, type_):
+    def visit_ENUM(self, type_, **kw):
         return self._visit_enumerated_values("ENUM", type_,
                                              type_._enumerated_values)
 
-    def visit_SET(self, type_):
+    def visit_SET(self, type_, **kw):
         return self._visit_enumerated_values("SET", type_,
                                              type_._enumerated_values)
 
-    def visit_BOOLEAN(self, type):
+    def visit_BOOLEAN(self, type, **kw):
         return "BOOL"
 
 
index 9f375da9443911cf78b031e8709f50878322d950..b482c9069d4f4f67cb935f115ffdd10447b340df 100644 (file)
@@ -457,19 +457,19 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
     # Oracle does not allow milliseconds in DATE
     # Oracle does not support TIME columns
 
-    def visit_datetime(self, type_):
-        return self.visit_DATE(type_)
+    def visit_datetime(self, type_, **kw):
+        return self.visit_DATE(type_, **kw)
 
-    def visit_float(self, type_):
-        return self.visit_FLOAT(type_)
+    def visit_float(self, type_, **kw):
+        return self.visit_FLOAT(type_, **kw)
 
-    def visit_unicode(self, type_):
+    def visit_unicode(self, type_, **kw):
         if self.dialect._supports_nchar:
-            return self.visit_NVARCHAR2(type_)
+            return self.visit_NVARCHAR2(type_, **kw)
         else:
-            return self.visit_VARCHAR2(type_)
+            return self.visit_VARCHAR2(type_, **kw)
 
-    def visit_INTERVAL(self, type_):
+    def visit_INTERVAL(self, type_, **kw):
         return "INTERVAL DAY%s TO SECOND%s" % (
             type_.day_precision is not None and
             "(%d)" % type_.day_precision or
@@ -479,22 +479,22 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
             "",
         )
 
-    def visit_LONG(self, type_):
+    def visit_LONG(self, type_, **kw):
         return "LONG"
 
-    def visit_TIMESTAMP(self, type_):
+    def visit_TIMESTAMP(self, type_, **kw):
         if type_.timezone:
             return "TIMESTAMP WITH TIME ZONE"
         else:
             return "TIMESTAMP"
 
-    def visit_DOUBLE_PRECISION(self, type_):
-        return self._generate_numeric(type_, "DOUBLE PRECISION")
+    def visit_DOUBLE_PRECISION(self, type_, **kw):
+        return self._generate_numeric(type_, "DOUBLE PRECISION", **kw)
 
     def visit_NUMBER(self, type_, **kw):
         return self._generate_numeric(type_, "NUMBER", **kw)
 
-    def _generate_numeric(self, type_, name, precision=None, scale=None):
+    def _generate_numeric(self, type_, name, precision=None, scale=None, **kw):
         if precision is None:
             precision = type_.precision
 
@@ -510,17 +510,17 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
             n = "%(name)s(%(precision)s, %(scale)s)"
             return n % {'name': name, 'precision': precision, 'scale': scale}
 
-    def visit_string(self, type_):
-        return self.visit_VARCHAR2(type_)
+    def visit_string(self, type_, **kw):
+        return self.visit_VARCHAR2(type_, **kw)
 
-    def visit_VARCHAR2(self, type_):
+    def visit_VARCHAR2(self, type_, **kw):
         return self._visit_varchar(type_, '', '2')
 
-    def visit_NVARCHAR2(self, type_):
+    def visit_NVARCHAR2(self, type_, **kw):
         return self._visit_varchar(type_, 'N', '2')
     visit_NVARCHAR = visit_NVARCHAR2
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         return self._visit_varchar(type_, '', '')
 
     def _visit_varchar(self, type_, n, num):
@@ -533,31 +533,31 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
             varchar = "%(n)sVARCHAR%(two)s(%(length)s)"
             return varchar % {'length': type_.length, 'two': num, 'n': n}
 
-    def visit_text(self, type_):
-        return self.visit_CLOB(type_)
+    def visit_text(self, type_, **kw):
+        return self.visit_CLOB(type_, **kw)
 
-    def visit_unicode_text(self, type_):
+    def visit_unicode_text(self, type_, **kw):
         if self.dialect._supports_nchar:
-            return self.visit_NCLOB(type_)
+            return self.visit_NCLOB(type_, **kw)
         else:
-            return self.visit_CLOB(type_)
+            return self.visit_CLOB(type_, **kw)
 
-    def visit_large_binary(self, type_):
-        return self.visit_BLOB(type_)
+    def visit_large_binary(self, type_, **kw):
+        return self.visit_BLOB(type_, **kw)
 
-    def visit_big_integer(self, type_):
-        return self.visit_NUMBER(type_, precision=19)
+    def visit_big_integer(self, type_, **kw):
+        return self.visit_NUMBER(type_, precision=19, **kw)
 
-    def visit_boolean(self, type_):
-        return self.visit_SMALLINT(type_)
+    def visit_boolean(self, type_, **kw):
+        return self.visit_SMALLINT(type_, **kw)
 
-    def visit_RAW(self, type_):
+    def visit_RAW(self, type_, **kw):
         if type_.length:
             return "RAW(%(length)s)" % {'length': type_.length}
         else:
             return "RAW"
 
-    def visit_ROWID(self, type_):
+    def visit_ROWID(self, type_, **kw):
         return "ROWID"
 
 
index 0817fe8371ca07ec6ece743ded8ba88368aebb59..89bea100e4e3a210c32da398cec92d5adc36f507 100644 (file)
@@ -1425,7 +1425,8 @@ class PGDDLCompiler(compiler.DDLCompiler):
             else:
                 colspec += " SERIAL"
         else:
-            colspec += " " + self.dialect.type_compiler.process(column.type)
+            colspec += " " + self.dialect.type_compiler.process(column.type,
+                                                    type_expression=column)
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
@@ -1545,94 +1546,93 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
 
 class PGTypeCompiler(compiler.GenericTypeCompiler):
-
-    def visit_TSVECTOR(self, type):
+    def visit_TSVECTOR(self, type, **kw):
         return "TSVECTOR"
 
-    def visit_INET(self, type_):
+    def visit_INET(self, type_, **kw):
         return "INET"
 
-    def visit_CIDR(self, type_):
+    def visit_CIDR(self, type_, **kw):
         return "CIDR"
 
-    def visit_MACADDR(self, type_):
+    def visit_MACADDR(self, type_, **kw):
         return "MACADDR"
 
-    def visit_OID(self, type_):
+    def visit_OID(self, type_, **kw):
         return "OID"
 
-    def visit_FLOAT(self, type_):
+    def visit_FLOAT(self, type_, **kw):
         if not type_.precision:
             return "FLOAT"
         else:
             return "FLOAT(%(precision)s)" % {'precision': type_.precision}
 
-    def visit_DOUBLE_PRECISION(self, type_):
+    def visit_DOUBLE_PRECISION(self, type_, **kw):
         return "DOUBLE PRECISION"
 
-    def visit_BIGINT(self, type_):
+    def visit_BIGINT(self, type_, **kw):
         return "BIGINT"
 
-    def visit_HSTORE(self, type_):
+    def visit_HSTORE(self, type_, **kw):
         return "HSTORE"
 
-    def visit_JSON(self, type_):
+    def visit_JSON(self, type_, **kw):
         return "JSON"
 
-    def visit_JSONB(self, type_):
+    def visit_JSONB(self, type_, **kw):
         return "JSONB"
 
-    def visit_INT4RANGE(self, type_):
+    def visit_INT4RANGE(self, type_, **kw):
         return "INT4RANGE"
 
-    def visit_INT8RANGE(self, type_):
+    def visit_INT8RANGE(self, type_, **kw):
         return "INT8RANGE"
 
-    def visit_NUMRANGE(self, type_):
+    def visit_NUMRANGE(self, type_, **kw):
         return "NUMRANGE"
 
-    def visit_DATERANGE(self, type_):
+    def visit_DATERANGE(self, type_, **kw):
         return "DATERANGE"
 
-    def visit_TSRANGE(self, type_):
+    def visit_TSRANGE(self, type_, **kw):
         return "TSRANGE"
 
-    def visit_TSTZRANGE(self, type_):
+    def visit_TSTZRANGE(self, type_, **kw):
         return "TSTZRANGE"
 
-    def visit_datetime(self, type_):
-        return self.visit_TIMESTAMP(type_)
+    def visit_datetime(self, type_, **kw):
+        return self.visit_TIMESTAMP(type_, **kw)
 
-    def visit_enum(self, type_):
+    def visit_enum(self, type_, **kw):
         if not type_.native_enum or not self.dialect.supports_native_enum:
-            return super(PGTypeCompiler, self).visit_enum(type_)
+            return super(PGTypeCompiler, self).visit_enum(type_, **kw)
         else:
-            return self.visit_ENUM(type_)
+            return self.visit_ENUM(type_, **kw)
 
-    def visit_ENUM(self, type_):
+    def visit_ENUM(self, type_, **kw):
         return self.dialect.identifier_preparer.format_type(type_)
 
-    def visit_TIMESTAMP(self, type_):
+    def visit_TIMESTAMP(self, type_, **kw):
         return "TIMESTAMP%s %s" % (
             getattr(type_, 'precision', None) and "(%d)" %
             type_.precision or "",
             (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
         )
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         return "TIME%s %s" % (
             getattr(type_, 'precision', None) and "(%d)" %
             type_.precision or "",
             (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
         )
 
-    def visit_INTERVAL(self, type_):
+    def visit_INTERVAL(self, type_, **kw):
         if type_.precision is not None:
             return "INTERVAL(%d)" % type_.precision
         else:
             return "INTERVAL"
 
-    def visit_BIT(self, type_):
+    def visit_BIT(self, type_, **kw):
         if type_.varying:
             compiled = "BIT VARYING"
             if type_.length is not None:
@@ -1641,16 +1641,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
             compiled = "BIT(%d)" % type_.length
         return compiled
 
-    def visit_UUID(self, type_):
+    def visit_UUID(self, type_, **kw):
         return "UUID"
 
-    def visit_large_binary(self, type_):
-        return self.visit_BYTEA(type_)
+    def visit_large_binary(self, type_, **kw):
+        return self.visit_BYTEA(type_, **kw)
 
-    def visit_BYTEA(self, type_):
+    def visit_BYTEA(self, type_, **kw):
         return "BYTEA"
 
-    def visit_ARRAY(self, type_):
+    def visit_ARRAY(self, type_, **kw):
         return self.process(type_.item_type) + ('[]' * (type_.dimensions
                                                         if type_.dimensions
                                                         is not None else 1))
index 3d7b0788b6cc5504a61a959f42882b36dd18f4e8..f7442196750abf489b79502f6c2b45ad0489d521 100644 (file)
@@ -660,7 +660,8 @@ class SQLiteCompiler(compiler.SQLCompiler):
 class SQLiteDDLCompiler(compiler.DDLCompiler):
 
     def get_column_specification(self, column, **kwargs):
-        coltype = self.dialect.type_compiler.process(column.type)
+        coltype = self.dialect.type_compiler.process(
+            column.type, type_expression=column)
         colspec = self.preparer.format_column(column) + " " + coltype
         default = self.get_column_default_string(column)
         if default is not None:
@@ -716,24 +717,24 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
 
 
 class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
-    def visit_large_binary(self, type_):
+    def visit_large_binary(self, type_, **kw):
         return self.visit_BLOB(type_)
 
-    def visit_DATETIME(self, type_):
+    def visit_DATETIME(self, type_, **kw):
         if not isinstance(type_, _DateTimeMixin) or \
                 type_.format_is_text_affinity:
             return super(SQLiteTypeCompiler, self).visit_DATETIME(type_)
         else:
             return "DATETIME_CHAR"
 
-    def visit_DATE(self, type_):
+    def visit_DATE(self, type_, **kw):
         if not isinstance(type_, _DateTimeMixin) or \
                 type_.format_is_text_affinity:
             return super(SQLiteTypeCompiler, self).visit_DATE(type_)
         else:
             return "DATE_CHAR"
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         if not isinstance(type_, _DateTimeMixin) or \
                 type_.format_is_text_affinity:
             return super(SQLiteTypeCompiler, self).visit_TIME(type_)
index f65a76a271cbb9e36d4d16c7febe1a03038f35b0..369420358cc890fbcfa7c46c86bb51900a30a253 100644 (file)
@@ -146,40 +146,40 @@ class IMAGE(sqltypes.LargeBinary):
 
 
 class SybaseTypeCompiler(compiler.GenericTypeCompiler):
-    def visit_large_binary(self, type_):
+    def visit_large_binary(self, type_, **kw):
         return self.visit_IMAGE(type_)
 
-    def visit_boolean(self, type_):
+    def visit_boolean(self, type_, **kw):
         return self.visit_BIT(type_)
 
-    def visit_unicode(self, type_):
+    def visit_unicode(self, type_, **kw):
         return self.visit_NVARCHAR(type_)
 
-    def visit_UNICHAR(self, type_):
+    def visit_UNICHAR(self, type_, **kw):
         return "UNICHAR(%d)" % type_.length
 
-    def visit_UNIVARCHAR(self, type_):
+    def visit_UNIVARCHAR(self, type_, **kw):
         return "UNIVARCHAR(%d)" % type_.length
 
-    def visit_UNITEXT(self, type_):
+    def visit_UNITEXT(self, type_, **kw):
         return "UNITEXT"
 
-    def visit_TINYINT(self, type_):
+    def visit_TINYINT(self, type_, **kw):
         return "TINYINT"
 
-    def visit_IMAGE(self, type_):
+    def visit_IMAGE(self, type_, **kw):
         return "IMAGE"
 
-    def visit_BIT(self, type_):
+    def visit_BIT(self, type_, **kw):
         return "BIT"
 
-    def visit_MONEY(self, type_):
+    def visit_MONEY(self, type_, **kw):
         return "MONEY"
 
-    def visit_SMALLMONEY(self, type_):
+    def visit_SMALLMONEY(self, type_, **kw):
         return "SMALLMONEY"
 
-    def visit_UNIQUEIDENTIFIER(self, type_):
+    def visit_UNIQUEIDENTIFIER(self, type_, **kw):
         return "UNIQUEIDENTIFIER"
 
 ischema_names = {
@@ -377,7 +377,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
 class SybaseDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + \
-            self.dialect.type_compiler.process(column.type)
+            self.dialect.type_compiler.process(
+                column.type, type_expression=column)
 
         if column.table is None:
             raise exc.CompileError(
index ca14c93710034315ab968043b81ef2f9febdf081..da62b14348b739d3b295fecc1bc7c6200671ba4a 100644 (file)
@@ -248,15 +248,16 @@ class Compiled(object):
         return self.execute(*multiparams, **params).scalar()
 
 
-class TypeCompiler(object):
-
+class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
     """Produces DDL specification for TypeEngine objects."""
 
+    ensure_kwarg = 'visit_\w+'
+
     def __init__(self, dialect):
         self.dialect = dialect
 
-    def process(self, type_):
-        return type_._compiler_dispatch(self)
+    def process(self, type_, **kw):
+        return type_._compiler_dispatch(self, **kw)
 
 
 class _CompileLabel(visitors.Visitable):
@@ -638,8 +639,9 @@ class SQLCompiler(Compiled):
     def visit_index(self, index, **kwargs):
         return index.name
 
-    def visit_typeclause(self, typeclause, **kwargs):
-        return self.dialect.type_compiler.process(typeclause.type)
+    def visit_typeclause(self, typeclause, **kw):
+        kw['type_expression'] = typeclause
+        return self.dialect.type_compiler.process(typeclause.type, **kw)
 
     def post_process_text(self, text):
         return text
@@ -2259,7 +2261,8 @@ class DDLCompiler(Compiled):
 
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + \
-            self.dialect.type_compiler.process(column.type)
+            self.dialect.type_compiler.process(
+                column.type, type_expression=column)
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
@@ -2383,13 +2386,13 @@ class DDLCompiler(Compiled):
 
 class GenericTypeCompiler(TypeCompiler):
 
-    def visit_FLOAT(self, type_):
+    def visit_FLOAT(self, type_, **kw):
         return "FLOAT"
 
-    def visit_REAL(self, type_):
+    def visit_REAL(self, type_, **kw):
         return "REAL"
 
-    def visit_NUMERIC(self, type_):
+    def visit_NUMERIC(self, type_, **kw):
         if type_.precision is None:
             return "NUMERIC"
         elif type_.scale is None:
@@ -2400,7 +2403,7 @@ class GenericTypeCompiler(TypeCompiler):
                 {'precision': type_.precision,
                  'scale': type_.scale}
 
-    def visit_DECIMAL(self, type_):
+    def visit_DECIMAL(self, type_, **kw):
         if type_.precision is None:
             return "DECIMAL"
         elif type_.scale is None:
@@ -2411,31 +2414,31 @@ class GenericTypeCompiler(TypeCompiler):
                 {'precision': type_.precision,
                  'scale': type_.scale}
 
-    def visit_INTEGER(self, type_):
+    def visit_INTEGER(self, type_, **kw):
         return "INTEGER"
 
-    def visit_SMALLINT(self, type_):
+    def visit_SMALLINT(self, type_, **kw):
         return "SMALLINT"
 
-    def visit_BIGINT(self, type_):
+    def visit_BIGINT(self, type_, **kw):
         return "BIGINT"
 
-    def visit_TIMESTAMP(self, type_):
+    def visit_TIMESTAMP(self, type_, **kw):
         return 'TIMESTAMP'
 
-    def visit_DATETIME(self, type_):
+    def visit_DATETIME(self, type_, **kw):
         return "DATETIME"
 
-    def visit_DATE(self, type_):
+    def visit_DATE(self, type_, **kw):
         return "DATE"
 
-    def visit_TIME(self, type_):
+    def visit_TIME(self, type_, **kw):
         return "TIME"
 
-    def visit_CLOB(self, type_):
+    def visit_CLOB(self, type_, **kw):
         return "CLOB"
 
-    def visit_NCLOB(self, type_):
+    def visit_NCLOB(self, type_, **kw):
         return "NCLOB"
 
     def _render_string_type(self, type_, name):
@@ -2447,91 +2450,91 @@ class GenericTypeCompiler(TypeCompiler):
             text += ' COLLATE "%s"' % type_.collation
         return text
 
-    def visit_CHAR(self, type_):
+    def visit_CHAR(self, type_, **kw):
         return self._render_string_type(type_, "CHAR")
 
-    def visit_NCHAR(self, type_):
+    def visit_NCHAR(self, type_, **kw):
         return self._render_string_type(type_, "NCHAR")
 
-    def visit_VARCHAR(self, type_):
+    def visit_VARCHAR(self, type_, **kw):
         return self._render_string_type(type_, "VARCHAR")
 
-    def visit_NVARCHAR(self, type_):
+    def visit_NVARCHAR(self, type_, **kw):
         return self._render_string_type(type_, "NVARCHAR")
 
-    def visit_TEXT(self, type_):
+    def visit_TEXT(self, type_, **kw):
         return self._render_string_type(type_, "TEXT")
 
-    def visit_BLOB(self, type_):
+    def visit_BLOB(self, type_, **kw):
         return "BLOB"
 
-    def visit_BINARY(self, type_):
+    def visit_BINARY(self, type_, **kw):
         return "BINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_VARBINARY(self, type_):
+    def visit_VARBINARY(self, type_, **kw):
         return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
 
-    def visit_BOOLEAN(self, type_):
+    def visit_BOOLEAN(self, type_, **kw):
         return "BOOLEAN"
 
-    def visit_large_binary(self, type_):
-        return self.visit_BLOB(type_)
+    def visit_large_binary(self, type_, **kw):
+        return self.visit_BLOB(type_, **kw)
 
-    def visit_boolean(self, type_):
-        return self.visit_BOOLEAN(type_)
+    def visit_boolean(self, type_, **kw):
+        return self.visit_BOOLEAN(type_, **kw)
 
-    def visit_time(self, type_):
-        return self.visit_TIME(type_)
+    def visit_time(self, type_, **kw):
+        return self.visit_TIME(type_, **kw)
 
-    def visit_datetime(self, type_):
-        return self.visit_DATETIME(type_)
+    def visit_datetime(self, type_, **kw):
+        return self.visit_DATETIME(type_, **kw)
 
-    def visit_date(self, type_):
-        return self.visit_DATE(type_)
+    def visit_date(self, type_, **kw):
+        return self.visit_DATE(type_, **kw)
 
-    def visit_big_integer(self, type_):
-        return self.visit_BIGINT(type_)
+    def visit_big_integer(self, type_, **kw):
+        return self.visit_BIGINT(type_, **kw)
 
-    def visit_small_integer(self, type_):
-        return self.visit_SMALLINT(type_)
+    def visit_small_integer(self, type_, **kw):
+        return self.visit_SMALLINT(type_, **kw)
 
-    def visit_integer(self, type_):
-        return self.visit_INTEGER(type_)
+    def visit_integer(self, type_, **kw):
+        return self.visit_INTEGER(type_, **kw)
 
-    def visit_real(self, type_):
-        return self.visit_REAL(type_)
+    def visit_real(self, type_, **kw):
+        return self.visit_REAL(type_, **kw)
 
-    def visit_float(self, type_):
-        return self.visit_FLOAT(type_)
+    def visit_float(self, type_, **kw):
+        return self.visit_FLOAT(type_, **kw)
 
-    def visit_numeric(self, type_):
-        return self.visit_NUMERIC(type_)
+    def visit_numeric(self, type_, **kw):
+        return self.visit_NUMERIC(type_, **kw)
 
-    def visit_string(self, type_):
-        return self.visit_VARCHAR(type_)
+    def visit_string(self, type_, **kw):
+        return self.visit_VARCHAR(type_, **kw)
 
-    def visit_unicode(self, type_):
-        return self.visit_VARCHAR(type_)
+    def visit_unicode(self, type_, **kw):
+        return self.visit_VARCHAR(type_, **kw)
 
-    def visit_text(self, type_):
-        return self.visit_TEXT(type_)
+    def visit_text(self, type_, **kw):
+        return self.visit_TEXT(type_, **kw)
 
-    def visit_unicode_text(self, type_):
-        return self.visit_TEXT(type_)
+    def visit_unicode_text(self, type_, **kw):
+        return self.visit_TEXT(type_, **kw)
 
-    def visit_enum(self, type_):
-        return self.visit_VARCHAR(type_)
+    def visit_enum(self, type_, **kw):
+        return self.visit_VARCHAR(type_, **kw)
 
-    def visit_null(self, type_):
+    def visit_null(self, type_, **kw):
         raise exc.CompileError("Can't generate DDL for %r; "
                                "did you forget to specify a "
                                "type on this Column?" % type_)
 
-    def visit_type_decorator(self, type_):
-        return self.process(type_.type_engine(self.dialect))
+    def visit_type_decorator(self, type_, **kw):
+        return self.process(type_.type_engine(self.dialect), **kw)
 
-    def visit_user_defined(self, type_):
-        return type_.get_col_spec()
+    def visit_user_defined(self, type_, **kw):
+        return type_.get_col_spec(**kw)
 
 
 class IdentifierPreparer(object):
index bff497800992f5ae9771623e5cedf65a39dc1549..19398ae96cf2ac521479019bcb5a03c01ac305cc 100644 (file)
@@ -12,7 +12,7 @@
 
 from .. import exc, util
 from . import operators
-from .visitors import Visitable
+from .visitors import Visitable, VisitableType
 
 # these are back-assigned by sqltypes.
 BOOLEANTYPE = None
@@ -460,7 +460,11 @@ class TypeEngine(Visitable):
         return util.generic_repr(self)
 
 
-class UserDefinedType(TypeEngine):
+class VisitableCheckKWArg(util.EnsureKWArgType, VisitableType):
+    pass
+
+
+class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)):
     """Base for user defined types.
 
     This should be the base of new types.  Note that
@@ -473,7 +477,7 @@ class UserDefinedType(TypeEngine):
           def __init__(self, precision = 8):
               self.precision = precision
 
-          def get_col_spec(self):
+          def get_col_spec(self, **kw):
               return "MYTYPE(%s)" % self.precision
 
           def bind_processor(self, dialect):
@@ -493,9 +497,23 @@ class UserDefinedType(TypeEngine):
           Column('data', MyType(16))
           )
 
+    The ``get_col_spec()`` method will in most cases receive a keyword
+    argument ``type_expression`` which refers to the owning expression
+    of the type as being compiled, such as a :class:`.Column` or
+    :func:`.cast` construct.  This keyword is only sent if the method
+    accepts keyword arguments (e.g. ``**kw``) in its argument signature;
+    introspection is used to check for this in order to support legacy
+    forms of this function.
+
+    .. versionadded:: 1.0.0 the owning expression is passed to
+       the ``get_col_spec()`` method via the keyword argument
+       ``type_expression``, if it receives ``**kw`` in its signature.
+
     """
     __visit_name__ = "user_defined"
 
+    ensure_kwarg = 'get_col_spec'
+
     class Comparator(TypeEngine.Comparator):
         __slots__ = ()
 
index c23b0196f370fa5b0e5b323a12d7a5516a455d2f..ceee18d86e05e396c75a48108a64b7c99b9d589e 100644 (file)
@@ -36,7 +36,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
     generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
     safe_reraise,\
     get_callable_argspec, only_once, attrsetter, ellipses_string, \
-    warn_limited, map_bits, MemoizedSlots
+    warn_limited, map_bits, MemoizedSlots, EnsureKWArgType
 
 from .deprecations import warn_deprecated, warn_pending_deprecation, \
     deprecated, pending_deprecation, inject_docstring_text
index 22b6ad4ca7b0ce2df8efe6eaf7e7a2e62e65132e..5a938501a9c6f6726582afb25c153d8970182e3a 100644 (file)
@@ -1348,6 +1348,7 @@ def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
 
 NoneType = type(None)
 
+
 def attrsetter(attrname):
     code = \
         "def set(obj, value):"\
@@ -1355,3 +1356,29 @@ def attrsetter(attrname):
     env = locals().copy()
     exec(code, env)
     return env['set']
+
+
+class EnsureKWArgType(type):
+    """Apply translation of functions to accept **kw arguments if they
+    don't already.
+
+    """
+    def __init__(cls, clsname, bases, clsdict):
+        fn_reg = cls.ensure_kwarg
+        if fn_reg:
+            for key in clsdict:
+                m = re.match(fn_reg, key)
+                if m:
+                    fn = clsdict[key]
+                    spec = inspect.getargspec(fn)
+                    if not spec.keywords:
+                        clsdict[key] = wrapped = cls._wrap_w_kw(fn)
+                        setattr(cls, key, wrapped)
+        super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
+
+    def _wrap_w_kw(self, fn):
+
+        def wrap(*arg, **kw):
+            return fn(*arg)
+        return update_wrapper(wrap, fn)
+
index 6ffd88d789a99cc63474a29a307c49468f3a4eab..38b3ced13655f816d16ec9cb7f2865e6af9322c9 100644 (file)
@@ -10,6 +10,8 @@ from sqlalchemy import (
     type_coerce, VARCHAR, Time, DateTime, BigInteger, SmallInteger, BOOLEAN,
     BLOB, NCHAR, NVARCHAR, CLOB, TIME, DATE, DATETIME, TIMESTAMP, SMALLINT,
     INTEGER, DECIMAL, NUMERIC, FLOAT, REAL)
+from sqlalchemy.sql import ddl
+
 from sqlalchemy import exc, types, util, dialects
 for name in dialects.__all__:
     __import__("sqlalchemy.dialects.%s" % name)
@@ -309,6 +311,24 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
             literal_binds=True
         )
 
+    def test_kw_colspec(self):
+        class MyType(types.UserDefinedType):
+            def get_col_spec(self, **kw):
+                return "FOOB %s" % kw['type_expression'].name
+
+        class MyOtherType(types.UserDefinedType):
+            def get_col_spec(self):
+                return "BAR"
+
+        self.assert_compile(
+            ddl.CreateColumn(Column('bar', MyType)),
+            "bar FOOB bar"
+        )
+        self.assert_compile(
+            ddl.CreateColumn(Column('bar', MyOtherType)),
+            "bar BAR"
+        )
+
     def test_typedecorator_literal_render_fallback_bound(self):
         # fall back to process_bind_param for literal
         # value rendering.
@@ -1642,6 +1662,49 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_decimal_scale(self):
         self.assert_compile(types.DECIMAL(2, 4), 'DECIMAL(2, 4)')
 
+    def test_kwarg_legacy_typecompiler(self):
+        from sqlalchemy.sql import compiler
+
+        class SomeTypeCompiler(compiler.GenericTypeCompiler):
+            # transparently decorated w/ kw decorator
+            def visit_VARCHAR(self, type_):
+                return "MYVARCHAR"
+
+            # not affected
+            def visit_INTEGER(self, type_, **kw):
+                return "MYINTEGER %s" % kw['type_expression'].name
+
+        dialect = default.DefaultDialect()
+        dialect.type_compiler = SomeTypeCompiler(dialect)
+        self.assert_compile(
+            ddl.CreateColumn(Column('bar', VARCHAR(50))),
+            "bar MYVARCHAR",
+            dialect=dialect
+        )
+        self.assert_compile(
+            ddl.CreateColumn(Column('bar', INTEGER)),
+            "bar MYINTEGER bar",
+            dialect=dialect
+        )
+
+
+class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase):
+    __backend__ = True
+
+    def test_user_defined(self):
+        """test that dialects pass the column through on DDL."""
+
+        class MyType(types.UserDefinedType):
+            def get_col_spec(self, **kw):
+                return "FOOB %s" % kw['type_expression'].name
+
+        m = MetaData()
+        t = Table('t', m, Column('bar', MyType))
+        self.assert_compile(
+            ddl.CreateColumn(t.c.bar),
+            "bar FOOB bar"
+        )
+
 
 class NumericRawSQLTest(fixtures.TestBase):