]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Include column name in length-less String warning (more [ticket:912])
authorJason Kirtland <jek@discorporate.us>
Thu, 10 Jan 2008 23:16:56 +0000 (23:16 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 10 Jan 2008 23:16:56 +0000 (23:16 +0000)
13 files changed:
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/maxdb.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/sql/testtypes.py

index bf972a1c34643fb58396c7d9f4e4448e3d755374..587470f8ef1f418bb074996cdc1c027e50640961 100644 (file)
@@ -378,7 +378,7 @@ class AccessCompiler(compiler.DefaultCompiler):
 
 class AccessSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
 
         # install a sequence if we have an implicit IDENTITY column
         if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
index 8700b6cce93d782d13d237c97ea7d30a08e4004c..e16593eb87ec692059ebb36859b9a31d5d0cd482 100644 (file)
@@ -587,7 +587,7 @@ class FBSchemaGenerator(sql.compiler.SchemaGenerator):
 
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
+        colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
 
         default = self.get_column_default_string(column)
         if default is not None:
index f27e7a5b8913e23c263ca3d7fd27f09f355474c3..eb07932433b464e89d054db14214a223d44ee820 100644 (file)
@@ -20,7 +20,7 @@ class informix_cursor(object):
     def __init__( self , con ):
         self.__cursor = con.cursor()
         self.rowcount = 0
-    
+
     def offset( self , n ):
         if n > 0:
             self.fetchmany( n )
@@ -29,13 +29,13 @@ class informix_cursor(object):
                 self.rowcount = 0
         else:
             self.rowcount = self.__cursor.rowcount
-            
+
     def execute( self , sql , params ):
         if params is None or len( params ) == 0:
             params = []
-        
+
         return self.__cursor.execute( sql , params )
-    
+
     def __getattr__( self , name ):
         if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
             return getattr( self.__cursor , name )
@@ -46,7 +46,7 @@ class InfoNumeric(sqltypes.Numeric):
             return 'NUMERIC'
         else:
             return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
-    
+
 class InfoInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -62,7 +62,7 @@ class InfoDate(sqltypes.Date):
 class InfoDateTime(sqltypes.DateTime ):
     def get_col_spec(self):
         return "DATETIME YEAR TO SECOND"
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value is not None:
@@ -70,7 +70,7 @@ class InfoDateTime(sqltypes.DateTime ):
                     value = value.replace( microsecond = 0 )
             return value
         return process
-        
+
 class InfoTime(sqltypes.Time ):
     def get_col_spec(self):
         return "DATETIME HOUR TO SECOND"
@@ -82,15 +82,15 @@ class InfoTime(sqltypes.Time ):
                     value = value.replace( microsecond = 0 )
             return value
         return process
-        
+
     def result_processor(self, dialect):
         def process(value):
             if isinstance( value , datetime.datetime ):
                 return value.time()
             else:
-                return value        
+                return value
         return process
-        
+
 class InfoText(sqltypes.String):
     def get_col_spec(self):
         return "VARCHAR(255)"
@@ -98,7 +98,7 @@ class InfoText(sqltypes.String):
 class InfoString(sqltypes.String):
     def get_col_spec(self):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value == '':
@@ -106,27 +106,27 @@ class InfoString(sqltypes.String):
             else:
                 return value
         return process
-        
+
 class InfoChar(sqltypes.CHAR):
     def get_col_spec(self):
         return "CHAR(%(length)s)" % {'length' : self.length}
-        
+
 class InfoBinary(sqltypes.Binary):
     def get_col_spec(self):
         return "BYTE"
-        
+
 class InfoBoolean(sqltypes.Boolean):
     default_type = 'NUM'
     def get_col_spec(self):
         return "SMALLINT"
-        
+
     def result_processor(self, dialect):
         def process(value):
             if value is None:
                 return None
             return value and True or False
         return process
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value is True:
@@ -138,7 +138,7 @@ class InfoBoolean(sqltypes.Boolean):
             else:
                 return value and True or False
         return process
-        
+
 colspecs = {
     sqltypes.Integer : InfoInteger,
     sqltypes.Smallinteger : InfoSmallInteger,
@@ -156,26 +156,26 @@ colspecs = {
 
 
 ischema_names = {
-    0   : InfoString,       # CHAR  
+    0   : InfoString,       # CHAR
     1   : InfoSmallInteger, # SMALLINT
-    2   : InfoInteger,      # INT     
+    2   : InfoInteger,      # INT
     3   : InfoNumeric,      # Float
     3   : InfoNumeric,      # SmallFloat
-    5   : InfoNumeric,      # DECIMAL 
+    5   : InfoNumeric,      # DECIMAL
     6   : InfoInteger,      # Serial
-    7   : InfoDate,         # DATE    
+    7   : InfoDate,         # DATE
     8   : InfoNumeric,      # MONEY
     10  : InfoDateTime,     # DATETIME
-    11  : InfoBinary,       # BYTE    
-    12  : InfoText,         # TEXT    
-    13  : InfoString,       # VARCHAR 
-    15  : InfoString,       # NCHAR  
-    16  : InfoString,       # NVARCHAR  
+    11  : InfoBinary,       # BYTE
+    12  : InfoText,         # TEXT
+    13  : InfoString,       # VARCHAR
+    15  : InfoString,       # NCHAR
+    16  : InfoString,       # NVARCHAR
     17  : InfoInteger,      # INT8
     18  : InfoInteger,      # Serial8
     43  : InfoString,       # LVARCHAR
-    -1  : InfoBinary,       # BLOB    
-    -1  : InfoText,         # CLOB    
+    -1  : InfoBinary,       # BLOB
+    -1  : InfoText,         # CLOB
 }
 
 def descriptor():
@@ -204,11 +204,11 @@ class InfoExecutionContext(default.DefaultExecutionContext):
 
     def create_cursor( self ):
         return informix_cursor( self.connection.connection )
-        
+
 class InfoDialect(default.DefaultDialect):
     # for informix 7.31
     max_identifier_length = 18
-    
+
     def __init__(self, use_ansi=True,**kwargs):
         self.use_ansi = use_ansi
         default.DefaultDialect.__init__(self, **kwargs)
@@ -229,7 +229,7 @@ class InfoDialect(default.DefaultDialect):
         cu = connect.cursor()
         cu.execute( 'SET LOCK MODE TO WAIT' )
         #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
-    
+
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
@@ -238,20 +238,20 @@ class InfoDialect(default.DefaultDialect):
             dsn = '%s@%s' % ( url.database , url.host )
         else:
             dsn = url.database
-        
+
         if url.username:
             opt = { 'user':url.username , 'password': url.password }
         else:
             opt = {}
-            
+
         return ([dsn,], opt )
-        
+
     def create_execution_context(self , *args, **kwargs):
         return InfoExecutionContext(self, *args, **kwargs)
-        
+
     def oid_column_name(self,column):
         return "rowid"
-    
+
     def table_names(self, connection, schema):
         s = "select tabname from systables"
         return [row[0] for row in connection.execute(s)]
@@ -259,7 +259,7 @@ class InfoDialect(default.DefaultDialect):
     def has_table(self, connection, table_name,schema=None):
         cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
         return bool( cursor.fetchone() is not None )
-        
+
     def reflecttable(self, connection, table, include_columns):
         c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() )
         rows = c.fetchall()
@@ -278,12 +278,12 @@ class InfoDialect(default.DefaultDialect):
                     raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)
 
         c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
-                                    where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? 
+                                    where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=?
                                       and t3.tabid = t2.tabid and t3.colno = t1.colno
                                     order by t1.colno""", table.name.lower(), owner )
         rows = c.fetchall()
-        
-        if not rows: 
+
+        if not rows:
             raise exceptions.NoSuchTableError(table.name)
 
         for name , colattr , collength , default , colno in rows:
@@ -293,11 +293,11 @@ class InfoDialect(default.DefaultDialect):
 
             # in 7.31, coltype = 0x000
             #                       ^^-- column type
-            #                      ^-- 1 not null , 0 null 
+            #                      ^-- 1 not null , 0 null
             nullable , coltype = divmod( colattr , 256 )
             if coltype not in ( 0 , 13 ) and default:
                 default = default.split()[-1]
-            
+
             if coltype == 0 or coltype == 13: # char , varchar
                 coltype = ischema_names.get(coltype, InfoString)(collength)
                 if default:
@@ -313,26 +313,26 @@ class InfoDialect(default.DefaultDialect):
                 except KeyError:
                     warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
                     coltype = sqltypes.NULLTYPE
-            
+
             colargs = []
             if default is not None:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
-            
+
             table.append_column(schema.Column(name, coltype, nullable = (nullable == 0), *colargs))
 
         # FK
-        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , 
-                                         t4.colname as local_column , t7.tabname as remote_table , 
-                                         t6.colname as remote_column 
-                                    from sysconstraints as t1 , systables as t2 , 
-                                         sysindexes as t3 , syscolumns as t4 , 
-                                         sysreferences as t5 , syscolumns as t6 , systables as t7 , 
+        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
+                                         t4.colname as local_column , t7.tabname as remote_table ,
+                                         t6.colname as remote_column
+                                    from sysconstraints as t1 , systables as t2 ,
+                                         sysindexes as t3 , syscolumns as t4 ,
+                                         sysreferences as t5 , syscolumns as t6 , systables as t7 ,
                                          sysconstraints as t8 , sysindexes as t9
                                    where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'R'
                                      and t3.tabid = t2.tabid and t3.idxname = t1.idxname
                                      and t4.tabid = t2.tabid and t4.colno = t3.part1
                                      and t5.constrid = t1.constrid and t8.constrid = t5.primary
-                                     and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname 
+                                     and t6.tabid = t5.ptabid and t6.colno = t9.part1 and t9.idxname = t8.idxname
                                      and t7.tabid = t5.ptabid""", table.name.lower(), owner )
         rows = c.fetchall()
         fks = {}
@@ -348,15 +348,15 @@ class InfoDialect(default.DefaultDialect):
                 fk[0].append(local_column)
             if refspec not in fk[1]:
                 fk[1].append(refspec)
-                
+
         for name, value in fks.iteritems():
             table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1] , None ))
-        
+
         # PK
-        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type , 
-                                         t4.colname as local_column 
-                                    from sysconstraints as t1 , systables as t2 , 
-                                         sysindexes as t3 , syscolumns as t4 
+        c = connection.execute("""select t1.constrname as cons_name , t1.constrtype as cons_type ,
+                                         t4.colname as local_column
+                                    from sysconstraints as t1 , systables as t2 ,
+                                         sysindexes as t3 , syscolumns as t4
                                    where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? and t1.constrtype = 'P'
                                      and t3.tabid = t2.tabid and t3.idxname = t1.idxname
                                      and t4.tabid = t2.tabid and t4.colno = t3.part1""", table.name.lower(), owner )
@@ -365,17 +365,17 @@ class InfoDialect(default.DefaultDialect):
             table.primary_key.add( table.c[local_column] )
 
 class InfoCompiler(compiler.DefaultCompiler):
-    """Info compiler modifies the lexical structure of Select statements to work under 
+    """Info compiler modifies the lexical structure of Select statements to work under
     non-ANSI configured Oracle databases, if the use_ansi flag is False."""
     def __init__(self, dialect, statement, parameters=None, **kwargs):
         self.limit = 0
         self.offset = 0
-        
+
         compiler.DefaultCompiler.__init__( self , dialect , statement , parameters , **kwargs )
-    
+
     def default_from(self):
         return " from systables where tabname = 'systables' "
-    
+
     def get_select_precolumns( self , select ):
         s = select._distinct and "DISTINCT " or ""
         # only has limit
@@ -385,27 +385,27 @@ class InfoCompiler(compiler.DefaultCompiler):
         else:
             s += ""
         return s
-    
+
     def visit_select(self, select):
         if select._offset:
             self.offset = select._offset
             self.limit  = select._limit or 0
         # the column in order by clause must in select too
-        
+
         def __label( c ):
             try:
                 return c._label.lower()
             except:
                 return ''
-        
-        # TODO: dont modify the original select, generate a new one        
+
+        # TODO: dont modify the original select, generate a new one
         a = [ __label(c) for c in select._raw_columns ]
         for c in select.order_by_clause.clauses:
             if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
                 select.append_column( c )
-        
+
         return compiler.DefaultCompiler.visit_select(self, select)
-        
+
     def limit_clause(self, select):
         return ""
 
@@ -418,7 +418,7 @@ class InfoCompiler(compiler.DefaultCompiler):
             return "CURRENT YEAR TO SECOND"
         else:
             return compiler.DefaultCompiler.visit_function( self , func )
-            
+
     def visit_clauselist(self, list):
         try:
             li = [ c for c in list.clauses if c.name != 'oid' ]
@@ -434,41 +434,41 @@ class InfoSchemaGenerator(compiler.SchemaGenerator):
             colspec += " SERIAL"
             self.has_serial = True
         else:
-            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
+            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
-        
+
         if not column.nullable:
             colspec += " NOT NULL"
-        
+
         return colspec
-        
+
     def post_create_table(self, table):
         if hasattr( self , 'has_serial' ):
             del self.has_serial
         return ''
-    
+
     def visit_primary_key_constraint(self, constraint):
         # for informix 7.31 not support constraint name
         name = constraint.name
         constraint.name = None
         super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint)
         constraint.name = name
-    
+
     def visit_unique_constraint(self, constraint):
         # for informix 7.31 not support constraint name
         name = constraint.name
         constraint.name = None
         super(InfoSchemaGenerator, self).visit_unique_constraint(constraint)
         constraint.name = name
-        
+
     def visit_foreign_key_constraint( self , constraint ):
         if constraint.name is not None:
             constraint.use_alter = True
         else:
             super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint )
-        
+
     def define_foreign_key(self, constraint):
         # for informix 7.31 not support constraint name
         if constraint.use_alter:
@@ -490,7 +490,7 @@ class InfoSchemaGenerator(compiler.SchemaGenerator):
 class InfoIdentifierPreparer(compiler.IdentifierPreparer):
     def __init__(self, dialect):
         super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
-    
+
     def _requires_quotes(self, value):
         return False
 
index be1c4e70c155c333e7d0db6b3de6d0af6c797e27..4aa59bc28c5aacad00e8aaae3211c8fdb64330e1 100644 (file)
@@ -953,7 +953,7 @@ class MaxDBIdentifierPreparer(compiler.IdentifierPreparer):
 class MaxDBSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kw):
         colspec = [self.preparer.format_column(column),
-                   column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()]
+                   column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()]
 
         if not column.nullable:
             colspec.append('NOT NULL')
index 7aeedad981b50c8db8104f91d77c61def7652220..572139d489017758fafc84f8b62e9954646353ab 100644 (file)
@@ -20,7 +20,7 @@
   Note that the start & increment values for sequences are optional
   and will default to 1,1.
 
-* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for 
+* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
   ``INSERT`` s)
 
 * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
@@ -34,7 +34,7 @@ Known issues / TODO:
 
 * pymssql has problems with binary and unicode data that this module
   does **not** work around
-  
+
 """
 
 import datetime, random, warnings, re, sys, operator
@@ -44,7 +44,7 @@ from sqlalchemy.sql import compiler, expression, operators as sqlops
 from sqlalchemy.engine import default, base
 from sqlalchemy import types as sqltypes
 from sqlalchemy.util import Decimal as _python_Decimal
-    
+
 MSSQL_RESERVED_WORDS = util.Set(['function'])
 
 class MSNumeric(sqltypes.Numeric):
@@ -67,9 +67,9 @@ class MSNumeric(sqltypes.Numeric):
                 # Not sure that this exception is needed
                 return value
             else:
-                return str(value) 
+                return str(value)
         return process
-        
+
     def get_col_spec(self):
         if self.precision is None:
             return "NUMERIC"
@@ -87,7 +87,7 @@ class MSFloat(sqltypes.Float):
                 return str(value)
             return None
         return process
-        
+
 class MSInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -123,7 +123,7 @@ class MSTime(sqltypes.Time):
 
     def __init__(self, *a, **kw):
         super(MSTime, self).__init__(False)
-    
+
     def get_col_spec(self):
         return "DATETIME"
 
@@ -135,7 +135,7 @@ class MSTime(sqltypes.Time):
                 value = datetime.datetime.combine(self.__zero_date, value)
             return value
         return process
-    
+
     def result_processor(self, dialect):
         def process(value):
             if type(value) is datetime.datetime:
@@ -144,7 +144,7 @@ class MSTime(sqltypes.Time):
                 return datetime.time(0, 0, 0)
             return value
         return process
-        
+
 class MSDateTime_adodbapi(MSDateTime):
     def result_processor(self, dialect):
         def process(value):
@@ -154,7 +154,7 @@ class MSDateTime_adodbapi(MSDateTime):
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
-        
+
 class MSDateTime_pyodbc(MSDateTime):
     def bind_processor(self, dialect):
         def process(value):
@@ -162,7 +162,7 @@ class MSDateTime_pyodbc(MSDateTime):
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
-        
+
 class MSDate_pyodbc(MSDate):
     def bind_processor(self, dialect):
         def process(value):
@@ -170,7 +170,7 @@ class MSDate_pyodbc(MSDate):
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
-    
+
     def result_processor(self, dialect):
         def process(value):
             # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date()
@@ -178,7 +178,7 @@ class MSDate_pyodbc(MSDate):
                 return value.date()
             return value
         return process
-        
+
 class MSDate_pymssql(MSDate):
     def result_processor(self, dialect):
         def process(value):
@@ -187,11 +187,11 @@ class MSDate_pymssql(MSDate):
                 return value.date()
             return value
         return process
-        
+
 class MSText(sqltypes.Text):
     def get_col_spec(self):
         if self.dialect.text_as_varchar:
-            return "VARCHAR(max)"            
+            return "VARCHAR(max)"
         else:
             return "TEXT"
 
@@ -238,7 +238,7 @@ class MSBoolean(sqltypes.Boolean):
                 return None
             return value and True or False
         return process
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value is True:
@@ -250,27 +250,27 @@ class MSBoolean(sqltypes.Boolean):
             else:
                 return value and True or False
         return process
-        
+
 class MSTimeStamp(sqltypes.TIMESTAMP):
     def get_col_spec(self):
         return "TIMESTAMP"
-        
+
 class MSMoney(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "MONEY"
-        
+
 class MSSmallMoney(MSMoney):
     def get_col_spec(self):
         return "SMALLMONEY"
-        
+
 class MSUniqueIdentifier(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "UNIQUEIDENTIFIER"
-        
+
 class MSVariant(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "SQL_VARIANT"
-        
+
 def descriptor():
     return {'name':'mssql',
     'description':'MSSQL',
@@ -297,7 +297,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
     def pre_exec(self):
         """MS-SQL has a special mode for inserting non-NULL values
         into IDENTITY columns.
-        
+
         Activate it if the feature is turned on and needed.
         """
         if self.compiled.isinsert:
@@ -328,7 +328,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
         and fetch recently inserted IDENTIFY values (works only for
         one column).
         """
-        
+
         if self.compiled.isinsert and self.HASIDENT and not self.IINSERT:
             if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
                 if self.dialect.use_scope_identity:
@@ -339,17 +339,17 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
                 self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
                 # print "LAST ROW ID", self._last_inserted_ids
         super(MSSQLExecutionContext, self).post_exec()
-    
+
     _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)',
                                re.I | re.UNICODE)
-    
+
     def returns_rows_text(self, statement):
         return self._ms_is_select.match(statement) is not None
 
 
-class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):    
+class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
     def pre_exec(self):
-        """where appropriate, issue "select scope_identity()" in the same statement"""                
+        """where appropriate, issue "select scope_identity()" in the same statement"""
         super(MSSQLExecutionContext_pyodbc, self).pre_exec()
         if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) \
                 and len(self.parameters) == 1 and self.dialect.use_scope_identity:
@@ -418,7 +418,7 @@ class MSSQLDialect(default.DefaultDialect):
             return dialect(*args, **kwargs)
         else:
             return object.__new__(cls, *args, **kwargs)
-                
+
     def __init__(self, auto_identity_insert=True, **params):
         super(MSSQLDialect, self).__init__(**params)
         self.auto_identity_insert = auto_identity_insert
@@ -442,7 +442,7 @@ class MSSQLDialect(default.DefaultDialect):
             else:
                 raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
     dbapi = classmethod(dbapi)
-    
+
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         opts.update(url.query)
@@ -477,20 +477,20 @@ class MSSQLDialect(default.DefaultDialect):
 
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
-            
+
     def do_execute(self, cursor, statement, params, context=None, **kwargs):
         if params == {}:
             params = ()
         try:
             super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs)
-        finally:        
+        finally:
             if context.IINSERT:
                 cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
-         
+
     def do_executemany(self, cursor, statement, params, context=None, **kwargs):
         try:
             super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs)
-        finally:        
+        finally:
             if context.IINSERT:
                 cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
 
@@ -511,7 +511,7 @@ class MSSQLDialect(default.DefaultDialect):
     def raw_connection(self, connection):
         """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
         try:
-            # TODO: probably want to move this to individual dialect subclasses to 
+            # TODO: probably want to move this to individual dialect subclasses to
             # save on the exception throw + simplify
             return connection.connection.__dict__['_pymssqlCnx__cnx']
         except:
@@ -536,14 +536,14 @@ class MSSQLDialect(default.DefaultDialect):
                        and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
                        or columns.c.table_name==tablename,
                    )
-        
+
         c = connection.execute(s)
         row  = c.fetchone()
         return row is not None
-        
+
     def reflecttable(self, connection, table, include_columns):
         import sqlalchemy.databases.information_schema as ischema
-        
+
         # Get base columns
         if table.schema is not None:
             current_schema = table.schema
@@ -556,7 +556,7 @@ class MSSQLDialect(default.DefaultDialect):
                        and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema)
                        or columns.c.table_name==table.name,
                    order_by=[columns.c.ordinal_position])
-        
+
         c = connection.execute(s)
         found_table = False
         while True:
@@ -565,9 +565,9 @@ class MSSQLDialect(default.DefaultDialect):
                 break
             found_table = True
             (name, type, nullable, charlen, numericprec, numericscale, default) = (
-                row[columns.c.column_name], 
-                row[columns.c.data_type], 
-                row[columns.c.is_nullable] == 'YES', 
+                row[columns.c.column_name],
+                row[columns.c.data_type],
+                row[columns.c.is_nullable] == 'YES',
                 row[columns.c.character_maximum_length],
                 row[columns.c.numeric_precision],
                 row[columns.c.numeric_scale],
@@ -582,21 +582,21 @@ class MSSQLDialect(default.DefaultDialect):
                     args.append(a)
             coltype = self.ischema_names.get(type, None)
             if coltype == MSString and charlen == -1:
-                coltype = MSText()                
+                coltype = MSText()
             else:
                 if coltype is None:
                     warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name)))
                     coltype = sqltypes.NULLTYPE
-                    
+
                 elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1:
                     args[0] = None
                 coltype = coltype(*args)
             colargs= []
             if default is not None:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
-                
+
             table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
-        
+
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
 
@@ -631,7 +631,7 @@ class MSSQLDialect(default.DefaultDialect):
         # Add constraints
         RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
         TC = self.uppercase_table(ischema.constraints)        #information_schema.table_constraints
-        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column 
+        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column
         R  = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column
 
         # Primary key constraints
@@ -672,7 +672,7 @@ class MSSQLDialect(default.DefaultDialect):
 class MSSQLDialect_pymssql(MSSQLDialect):
     supports_sane_rowcount = False
     max_identifier_length = 30
-    
+
     def import_dbapi(cls):
         import pymssql as module
         # pymmsql doesn't have a Binary method.  we use string
@@ -680,7 +680,7 @@ class MSSQLDialect_pymssql(MSSQLDialect):
         module.Binary = lambda st: str(st)
         return module
     import_dbapi = classmethod(import_dbapi)
-    
+
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Date] = MSDate_pymssql
 
@@ -723,7 +723,7 @@ class MSSQLDialect_pymssql(MSSQLDialect):
 ##    This code is leftover from the initial implementation, for reference
 ##    def do_begin(self, connection):
 ##        """implementations might want to put logic here for turning autocommit on/off, etc."""
-##        pass  
+##        pass
 
 ##    def do_rollback(self, connection):
 ##        """implementations might want to put logic here for turning autocommit on/off, etc."""
@@ -740,7 +740,7 @@ class MSSQLDialect_pymssql(MSSQLDialect):
 
 ##    def do_commit(self, connection):
 ##        """implementations might want to put logic here for turning autocommit on/off, etc.
-##            do_commit is set for pymmsql connections--ADO seems to handle transactions without any issue 
+##            do_commit is set for pymmsql connections--ADO seems to handle transactions without any issue
 ##        """
 ##        # ADO Uses Implicit Transactions.
 ##        # This is very pymssql specific.  We use this instead of its commit, because it hangs on failed rollbacks.
@@ -757,7 +757,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
     # PyODBC unicode is broken on UCS-4 builds
     supports_unicode = sys.maxunicode == 65535
     supports_unicode_statements = supports_unicode
-    
+
     def __init__(self, **params):
         super(MSSQLDialect_pyodbc, self).__init__(**params)
         # whether use_scope_identity will work depends on the version of pyodbc
@@ -766,12 +766,12 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
             self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
         except:
             pass
-        
+
     def import_dbapi(cls):
         import pyodbc as module
         return module
     import_dbapi = classmethod(import_dbapi)
-    
+
     colspecs = MSSQLDialect.colspecs.copy()
     if supports_unicode:
         colspecs[sqltypes.Unicode] = AdoMSNVarchar
@@ -883,10 +883,10 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
         return s
 
-    def limit_clause(self, select):    
+    def limit_clause(self, select):
         # Limit in mssql is after the select keyword
         return ""
-            
+
     def _schema_aliased_table(self, table):
         if getattr(table, 'schema', None) is not None:
             if table not in self.tablealiases:
@@ -894,7 +894,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             return self.tablealiases[table]
         else:
             return None
-            
+
     def visit_table(self, table, mssql_aliased=False, **kwargs):
         if mssql_aliased:
             return super(MSSQLCompiler, self).visit_table(table, **kwargs)
@@ -905,7 +905,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             return self.process(alias, mssql_aliased=True, **kwargs)
         else:
             return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
     def visit_alias(self, alias, **kwargs):
         # translate for schema-qualified table aliases
         self.tablealiases[alias.original] = alias
@@ -956,8 +956,8 @@ class MSSQLCompiler(compiler.DefaultCompiler):
 
 class MSSQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
-        
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
         if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
                 column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
@@ -974,7 +974,7 @@ class MSSQLSchemaGenerator(compiler.SchemaGenerator):
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
-        
+
         return colspec
 
 class MSSQLSchemaDropper(compiler.SchemaDropper):
index 588ea341dc0876fb34c270b969fd68a8c23133f3..17defbb70c8be9655616d554cb475a4d838470e1 100644 (file)
@@ -1956,7 +1956,8 @@ class MySQLSchemaGenerator(compiler.SchemaGenerator):
         """Builds column DDL."""
 
         colspec = [self.preparer.format_column(column),
-                   column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()]
+                   column.type.dialect_impl(self.dialect,
+                                            _for_ddl=column).get_col_spec()]
 
         default = self.get_column_default_string(column)
         if default is not None:
index 7ac8f89518b3cdced178e5bb0ea6971d7c7d7d57..394aba178cb15e54bd9e9bfb5c6a188720e863ba 100644 (file)
@@ -44,11 +44,11 @@ class OracleDate(sqltypes.Date):
             else:
                 return value.date()
         return process
-        
+
 class OracleDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATE"
-        
+
     def result_processor(self, dialect):
         def process(value):
             if value is None or isinstance(value,datetime.datetime):
@@ -58,7 +58,7 @@ class OracleDateTime(sqltypes.DateTime):
                 return datetime.datetime(value.year,value.month,
                     value.day,value.hour, value.minute, value.second)
         return process
-        
+
 # Note:
 # Oracle DATE == DATETIME
 # Oracle does not allow milliseconds in DATE
@@ -135,7 +135,7 @@ class OracleBinary(sqltypes.Binary):
             else:
                 return value
         return process
-        
+
 class OracleBoolean(sqltypes.Boolean):
     def get_col_spec(self):
         return "SMALLINT"
@@ -146,7 +146,7 @@ class OracleBoolean(sqltypes.Boolean):
                 return None
             return value and True or False
         return process
-        
+
     def bind_processor(self, dialect):
         def process(value):
             if value is True:
@@ -158,7 +158,7 @@ class OracleBoolean(sqltypes.Boolean):
             else:
                 return value and True or False
         return process
-        
+
 colspecs = {
     sqltypes.Integer : OracleInteger,
     sqltypes.Smallinteger : OracleSmallInteger,
@@ -230,7 +230,7 @@ class OracleExecutionContext(default.DefaultExecutionContext):
                 type_code = column[1]
                 if type_code in self.dialect.ORACLE_BINARY_TYPES:
                     return base.BufferedColumnResultProxy(self)
-        
+
         return base.ResultProxy(self)
 
 class OracleDialect(default.DefaultDialect):
@@ -258,9 +258,9 @@ class OracleDialect(default.DefaultDialect):
             # etc. leads to a little too much magic, reflection doesn't know if it should
             # expect encoded strings or unicodes, etc.
             self.dbapi_type_map = {
-                self.dbapi.CLOB: OracleText(), 
-                self.dbapi.BLOB: OracleBinary(), 
-                self.dbapi.BINARY: OracleRaw(), 
+                self.dbapi.CLOB: OracleText(),
+                self.dbapi.BLOB: OracleBinary(),
+                self.dbapi.BINARY: OracleRaw(),
             }
             self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
 
@@ -268,7 +268,7 @@ class OracleDialect(default.DefaultDialect):
         import cx_Oracle
         return cx_Oracle
     dbapi = classmethod(dbapi)
-    
+
     def create_connect_args(self, url):
         dialect_opts = dict(url.query)
         for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
@@ -340,10 +340,10 @@ class OracleDialect(default.DefaultDialect):
 
     def do_begin_twophase(self, connection, xid):
         connection.connection.begin(*xid)
-        
+
     def do_prepare_twophase(self, connection, xid):
         connection.connection.prepare()
-        
+
     def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
         self.do_rollback(connection.connection)
 
@@ -352,7 +352,7 @@ class OracleDialect(default.DefaultDialect):
 
     def do_recover_twophase(self, connection):
         pass
-        
+
     def create_execution_context(self, *args, **kwargs):
         return OracleExecutionContext(self, *args, **kwargs)
 
@@ -433,7 +433,7 @@ class OracleDialect(default.DefaultDialect):
             return name.lower().decode(self.encoding)
         else:
             return name.decode(self.encoding)
-    
+
     def _denormalize_name(self, name):
         if name is None:
             return None
@@ -441,7 +441,7 @@ class OracleDialect(default.DefaultDialect):
             return name.upper().encode(self.encoding)
         else:
             return name.encode(self.encoding)
-    
+
     def get_default_schema_name(self,connection):
         try:
             return self._default_schema_name
@@ -469,7 +469,7 @@ class OracleDialect(default.DefaultDialect):
 
         c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner})
 
-                
+
         while True:
             row = c.fetchone()
             if row is None:
@@ -570,7 +570,7 @@ class _OuterJoinColumn(sql.ClauseElement):
     __visit_name__ = 'outer_join_column'
     def __init__(self, column):
         self.column = column
-        
+
 class OracleCompiler(compiler.DefaultCompiler):
     """Oracle compiler modifies the lexical structure of Select
     statements to work under non-ANSI configured Oracle databases, if
@@ -587,7 +587,7 @@ class OracleCompiler(compiler.DefaultCompiler):
     def __init__(self, *args, **kwargs):
         super(OracleCompiler, self).__init__(*args, **kwargs)
         self.__wheres = {}
-        
+
     def default_from(self):
         """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
 
@@ -612,7 +612,7 @@ class OracleCompiler(compiler.DefaultCompiler):
                         binary.left = _OuterJoinColumn(binary.left)
                     elif binary.right.table is join.right:
                         binary.right = _OuterJoinColumn(binary.right)
-        
+
         if join.isouter:
             if where is not None:
                 self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
@@ -623,24 +623,24 @@ class OracleCompiler(compiler.DefaultCompiler):
                 self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(join.onclause, where), parentjoin)
             else:
                 self.__wheres[join.left] = self.__wheres[join] = (join.onclause, join)
-            
+
         return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-    
+
     def get_whereclause(self, f):
         if f in self.__wheres:
             return self.__wheres[f][0]
         else:
             return None
-            
+
     def visit_outer_join_column(self, vc):
         return self.process(vc.column) + "(+)"
-        
+
     def visit_sequence(self, seq):
         return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
-        
+
     def visit_alias(self, alias, asfrom=False, **kwargs):
         """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
-        
+
         if asfrom:
             return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
         else:
@@ -661,11 +661,11 @@ class OracleCompiler(compiler.DefaultCompiler):
             if not orderby:
                 orderby = list(select.oid_column.proxies)[0]
                 orderby = self.process(orderby)
-                
+
             oldselect = select
             select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
             select._oracle_visit = True
-                
+
             limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
             if select._offset is not None:
                 limitselect.append_whereclause("ora_rn>%d" % select._offset)
@@ -690,7 +690,7 @@ class OracleCompiler(compiler.DefaultCompiler):
 class OracleSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
+        colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
@@ -719,7 +719,7 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
         name = re.sub(r'^_+', '', savepoint.ident)
         return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
 
-    
+
 dialect = OracleDialect
 dialect.statement_compiler = OracleCompiler
 dialect.schemagenerator = OracleSchemaGenerator
index 6d29430b98d98eacbb745de358f2ed5c440b25be..62372698026b737fe306a4ce4ed5a4b4738ff396 100644 (file)
@@ -14,8 +14,8 @@ option to the Index constructor::
 PostgreSQL 8.2+ supports returning a result set from inserts and updates.
 To use this pass the column/expression list to the postgres_returning
 parameter when creating the queries::
-    
-  raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), 
+
+  raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
     postgres_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
 """
 
@@ -31,7 +31,7 @@ from sqlalchemy import types as sqltypes
 class PGInet(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "INET"
-    
+
 class PGMacAddr(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "MACADDR"
@@ -56,7 +56,7 @@ class PGNumeric(sqltypes.Numeric):
                 else:
                     return value
             return process
-        
+
 class PGFloat(sqltypes.Float):
     def get_col_spec(self):
         if not self.precision:
@@ -118,13 +118,13 @@ class PGArray(sqltypes.Concatenable, sqltypes.TypeEngine):
         if isinstance(item_type, type):
             item_type = item_type()
         self.item_type = item_type
-        
+
     def dialect_impl(self, dialect, **kwargs):
         impl = self.__class__.__new__(self.__class__)
         impl.__dict__.update(self.__dict__)
         impl.item_type = self.item_type.dialect_impl(dialect)
         return impl
-        
+
     def bind_processor(self, dialect):
         item_proc = self.item_type.bind_processor(dialect)
         def process(value):
@@ -140,7 +140,7 @@ class PGArray(sqltypes.Concatenable, sqltypes.TypeEngine):
                         return item
             return [convert_item(item) for item in value]
         return process
-        
+
     def result_processor(self, dialect):
         item_proc = self.item_type.result_processor(dialect)
         def process(value):
@@ -242,15 +242,15 @@ class PGExecutionContext(default.DefaultExecutionContext):
         m = SELECT_RE.match(statement)
         return m and (not m.group(1) or (RETURNING_RE.search(statement)
            and RETURNING_QUOTED_RE.match(statement)))
-    
+
     def returns_rows_compiled(self, compiled):
         return isinstance(compiled.statement, expression.Selectable) or \
             (
                 (compiled.isupdate or compiled.isinsert) and "postgres_returning" in compiled.statement.kwargs
             )
-        
+
     def create_cursor(self):
-        # executing a default or Sequence standalone creates an execution context without a statement.  
+        # executing a default or Sequence standalone creates an execution context without a statement.
         # so slightly hacky "if no statement assume we're server side" logic
         # TODO: dont use regexp if Compiled is used ?
         self.__is_server_side = \
@@ -272,7 +272,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
             return base.BufferedRowResultProxy(self)
         else:
             return base.ResultProxy(self)
-    
+
     def post_exec(self):
         if self.compiled.isinsert and self.last_inserted_ids is None:
             if not self.dialect.use_oids:
@@ -285,7 +285,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
                     row = self.connection.execute(s).fetchone()
                 self._last_inserted_ids = [v for v in row]
         super(PGExecutionContext, self).post_exec()
-        
+
 class PGDialect(default.DefaultDialect):
     supports_alter = True
     supports_unicode_statements = False
@@ -300,12 +300,12 @@ class PGDialect(default.DefaultDialect):
         self.use_oids = use_oids
         self.server_side_cursors = server_side_cursors
         self.paramstyle = 'pyformat'
-        
+
     def dbapi(cls):
         import psycopg2 as psycopg
         return psycopg
     dbapi = classmethod(dbapi)
-    
+
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         if 'port' in opts:
@@ -390,10 +390,10 @@ class PGDialect(default.DefaultDialect):
             return "losed the connection unexpectedly" in str(e)
         else:
             return False
-        
+
     def table_names(self, connection, schema):
         s = """
-        SELECT relname 
+        SELECT relname
         FROM pg_class c
         WHERE relkind = 'r'
           AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
@@ -417,7 +417,7 @@ class PGDialect(default.DefaultDialect):
         else:
             schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
             schemaname = None
-            
+
         SQL_COLS = """
             SELECT a.attname,
               pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -447,11 +447,11 @@ class PGDialect(default.DefaultDialect):
             raise exceptions.NoSuchTableError(table.name)
 
         domains = self._load_domains(connection)
-        
+
         for name, format_type, default, notnull, attnum, table_oid in rows:
             if include_columns and name not in include_columns:
                 continue
-            
+
             ## strip (30) from character varying(30)
             attype = re.search('([^\([]+)', format_type).group(1)
             nullable = not notnull
@@ -563,7 +563,7 @@ class PGDialect(default.DefaultDialect):
             if referred_schema:
                 referred_schema = preparer._unquote_identifier(referred_schema)
             elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
-                # no schema (i.e. its the default schema), and the table we're 
+                # no schema (i.e. its the default schema), and the table we're
                 # reflecting has the default schema explicit, then use that.
                 # i.e. try to use the user's conventions
                 referred_schema = table.schema
@@ -582,7 +582,7 @@ class PGDialect(default.DefaultDialect):
                     refspec.append(".".join([referred_table, column]))
 
             table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname))
-                
+
     def _load_domains(self, connection):
         ## Load data types for domains:
         SQL_DOMAINS = """
@@ -606,7 +606,7 @@ class PGDialect(default.DefaultDialect):
             ## strip (30) from character varying(30)
             attype = re.search('([^\(]+)', domain['attype']).group(1)
             if domain['visible']:
-                # 'visible' just means whether or not the domain is in a 
+                # 'visible' just means whether or not the domain is in a
                 # schema that's on the search path -- or not overriden by
                 # a schema with higher presedence. If it's not visible,
                 # it will be prefixed with the schema-name when it's used.
@@ -617,9 +617,9 @@ class PGDialect(default.DefaultDialect):
             domains[name] = {'attype':attype, 'nullable': domain['nullable'], 'default': domain['default']}
 
         return domains
-        
-        
-        
+
+
+
 class PGCompiler(compiler.DefaultCompiler):
     operators = compiler.DefaultCompiler.operators.copy()
     operators.update(
@@ -633,7 +633,7 @@ class PGCompiler(compiler.DefaultCompiler):
             return None
         else:
             return "nextval('%s')" % self.preparer.format_sequence(seq)
-        
+
     def limit_clause(self, select):
         text = ""
         if select._limit is not None:
@@ -699,7 +699,7 @@ class PGSchemaGenerator(compiler.SchemaGenerator):
             else:
                 colspec += " SERIAL"
         else:
-            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
+            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
@@ -712,7 +712,7 @@ class PGSchemaGenerator(compiler.SchemaGenerator):
         if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
             self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
-    
+
     def visit_index(self, index):
         preparer = self.preparer
         self.append("CREATE ")
index 92645e524f1aee5f399d2dba402bae5979db32bd..36e05a067a15df3db081f6d9b3457e76b8fd9544 100644 (file)
@@ -16,7 +16,7 @@ from sqlalchemy.sql import compiler
 
 
 SELECT_REGEXP = re.compile(r'\s*(?:SELECT|PRAGMA)', re.I | re.UNICODE)
-    
+
 class SLNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
         type_ = self.asdecimal and str or float
@@ -57,7 +57,7 @@ class DateTimeMixin(object):
             else:
                 return None
         return process
-        
+
     def _cvt(self, value, dialect):
         if value is None:
             return None
@@ -71,7 +71,7 @@ class DateTimeMixin(object):
 class SLDateTime(DateTimeMixin,sqltypes.DateTime):
     __format__ = "%Y-%m-%d %H:%M:%S"
     __microsecond__ = True
-    
+
     def get_col_spec(self):
         return "TIMESTAMP"
 
@@ -80,7 +80,7 @@ class SLDateTime(DateTimeMixin,sqltypes.DateTime):
             tup = self._cvt(value, dialect)
             return tup and datetime.datetime(*tup)
         return process
-        
+
 class SLDate(DateTimeMixin, sqltypes.Date):
     __format__ = "%Y-%m-%d"
     __microsecond__ = False
@@ -93,7 +93,7 @@ class SLDate(DateTimeMixin, sqltypes.Date):
             tup = self._cvt(value, dialect)
             return tup and datetime.date(*tup[0:3])
         return process
-        
+
 class SLTime(DateTimeMixin, sqltypes.Time):
     __format__ = "%H:%M:%S"
     __microsecond__ = True
@@ -106,7 +106,7 @@ class SLTime(DateTimeMixin, sqltypes.Time):
             tup = self._cvt(value, dialect)
             return tup and datetime.time(*tup[3:7])
         return process
-        
+
 class SLText(sqltypes.Text):
     def get_col_spec(self):
         return "TEXT"
@@ -133,14 +133,14 @@ class SLBoolean(sqltypes.Boolean):
                 return None
             return value and 1 or 0
         return process
-    
+
     def result_processor(self, dialect):
         def process(value):
             if value is None:
                 return None
             return value and True or False
         return process
-        
+
 colspecs = {
     sqltypes.Integer : SLInteger,
     sqltypes.Smallinteger : SLSmallInteger,
@@ -171,7 +171,7 @@ ischema_names = {
     'DATETIME' : SLDateTime,
     'DATE' : SLDate,
     'BLOB' : SLBinary,
-    'BOOL': SLBoolean, 
+    'BOOL': SLBoolean,
     'BOOLEAN': SLBoolean,
 }
 
@@ -190,11 +190,11 @@ class SQLiteExecutionContext(default.DefaultExecutionContext):
 
     def returns_rows_text(self, statement):
         return SELECT_REGEXP.match(statement)
-        
+
 class SQLiteDialect(default.DefaultDialect):
     supports_alter = False
     supports_unicode_statements = True
-    
+
     def __init__(self, **kwargs):
         default.DefaultDialect.__init__(self, default_paramstyle='qmark', **kwargs)
         def vers(num):
@@ -204,7 +204,7 @@ class SQLiteDialect(default.DefaultDialect):
             if sqlite_ver < (2,1,'3'):
                 warnings.warn(RuntimeWarning("The installed version of pysqlite2 (%s) is out-dated, and will cause errors in some cases.  Version 2.1.3 or greater is recommended." % '.'.join([str(subver) for subver in sqlite_ver])))
         self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
-        
+
     def dbapi(cls):
         try:
             from pysqlite2 import dbapi2 as sqlite
@@ -239,7 +239,7 @@ class SQLiteDialect(default.DefaultDialect):
 
     def oid_column_name(self, column):
         return "oid"
-    
+
     def is_disconnect(self, e):
         return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
 
@@ -283,7 +283,7 @@ class SQLiteDialect(default.DefaultDialect):
             except KeyError:
                 warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (coltype, name)))
                 coltype = sqltypes.NullType
-                
+
             if args is not None:
                 args = re.findall(r'(\d+)', args)
                 coltype = coltype(*[int(a) for a in args])
@@ -386,7 +386,7 @@ class SQLiteCompiler(compiler.DefaultCompiler):
 class SQLiteSchemaGenerator(compiler.SchemaGenerator):
 
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
@@ -438,4 +438,3 @@ dialect.statement_compiler = SQLiteCompiler
 dialect.schemagenerator = SQLiteSchemaGenerator
 dialect.schemadropper = SQLiteSchemaDropper
 dialect.preparer = SQLiteIdentifierPreparer
-
index f461125aa3fd12f2e6c031b64b47245dd11fa233..f7c3d8a0f00a7cc6de942b4f745814728a5aa5b1 100644 (file)
@@ -9,7 +9,7 @@
 Sybase database backend.
 
 Known issues / TODO:
-  
+
  * Uses the mx.ODBC driver from egenix (version 2.1.0)
  * The current version of sqlalchemy.databases.sybase only supports
    mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
@@ -38,11 +38,11 @@ __all__ = [
     'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger',
     'SybaseTinyInteger', 'SybaseSmallInteger',
     'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc',
-    'SybaseDate_mxodbc', 'SybaseDate_pyodbc', 
-    'SybaseTime_mxodbc', 'SybaseTime_pyodbc', 
+    'SybaseDate_mxodbc', 'SybaseDate_pyodbc',
+    'SybaseTime_mxodbc', 'SybaseTime_pyodbc',
     'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary',
-    'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney', 
-    'SybaseUniqueIdentifier', 
+    'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney',
+    'SybaseUniqueIdentifier',
     ]
 
 
@@ -162,7 +162,7 @@ class SybaseTypeError(sqltypes.TypeEngine):
         def process(value):
             raise exceptions.NotSupportedError("Data type not supported", [value])
         return process
-        
+
     def get_col_spec(self):
         raise exceptions.NotSupportedError("Data type not supported")
 
@@ -180,7 +180,7 @@ class SybaseFloat(sqltypes.FLOAT, SybaseNumeric):
     def __init__(self, precision = 10, asdecimal = False, length = 2, **kwargs):
         super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs)
         self.length = length
-    
+
     def get_col_spec(self):
         # if asdecimal is True, handle same way as SybaseNumeric
         if self.asdecimal:
@@ -198,7 +198,7 @@ class SybaseFloat(sqltypes.FLOAT, SybaseNumeric):
         if self.asdecimal:
             return SybaseNumeric.result_processor(self, dialect)
         return process
-        
+
 class SybaseInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
@@ -221,7 +221,7 @@ class SybaseDateTime_mxodbc(sqltypes.DateTime):
 
     def get_col_spec(self):
         return "DATETIME"
-        
+
 class SybaseDateTime_pyodbc(sqltypes.DateTime):
     def __init__(self, *a, **kw):
         super(SybaseDateTime_pyodbc, self).__init__(False)
@@ -242,7 +242,7 @@ class SybaseDateTime_pyodbc(sqltypes.DateTime):
             if value is None:
                 return None
             return value
-        return process    
+        return process
 
 class SybaseDate_mxodbc(sqltypes.Date):
     def __init__(self, *a, **kw):
@@ -261,10 +261,10 @@ class SybaseDate_pyodbc(sqltypes.Date):
 class SybaseTime_mxodbc(sqltypes.Time):
     def __init__(self, *a, **kw):
         super(SybaseTime_mxodbc, self).__init__(False)
-    
+
     def get_col_spec(self):
         return "DATETIME"
-            
+
     def result_processor(self, dialect):
         def process(value):
             if value is None:
@@ -276,10 +276,10 @@ class SybaseTime_mxodbc(sqltypes.Time):
 class SybaseTime_pyodbc(sqltypes.Time):
     def __init__(self, *a, **kw):
         super(SybaseTime_pyodbc, self).__init__(False)
-        
+
     def get_col_spec(self):
         return "DATETIME"
-    
+
     def result_processor(self, dialect):
         def process(value):
             if value is None:
@@ -287,7 +287,7 @@ class SybaseTime_pyodbc(sqltypes.Time):
             # Convert the datetime.datetime back to datetime.time
             return datetime.time(value.hour, value.minute, value.second, value.microsecond)
         return process
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value is None:
@@ -297,7 +297,7 @@ class SybaseTime_pyodbc(sqltypes.Time):
 
 class SybaseText(sqltypes.Text):
     def get_col_spec(self):
-        return "TEXT"            
+        return "TEXT"
 
 class SybaseString(sqltypes.String):
     def get_col_spec(self):
@@ -321,7 +321,7 @@ class SybaseBoolean(sqltypes.Boolean):
                 return None
             return value and True or False
         return process
-    
+
     def bind_processor(self, dialect):
         def process(value):
             if value is True:
@@ -333,23 +333,23 @@ class SybaseBoolean(sqltypes.Boolean):
             else:
                 return value and True or False
         return process
-        
+
 class SybaseTimeStamp(sqltypes.TIMESTAMP):
     def get_col_spec(self):
         return "TIMESTAMP"
-        
+
 class SybaseMoney(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "MONEY"
-        
+
 class SybaseSmallMoney(SybaseMoney):
     def get_col_spec(self):
         return "SMALLMONEY"
-        
+
 class SybaseUniqueIdentifier(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "UNIQUEIDENTIFIER"
-        
+
 def descriptor():
     return {'name':'sybase',
     'description':'SybaseSQL',
@@ -364,18 +364,18 @@ class SybaseSQLExecutionContext(default.DefaultExecutionContext):
     pass
 
 class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext):
-    
+
     def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
         super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-    
+
     def pre_exec(self):
         super(SybaseSQLExecutionContext_mxodbc, self).pre_exec()
-        
+
     def post_exec(self):
         if self.compiled.isinsert:
             table = self.compiled.statement.table
             # get the inserted values of the primary key
-            
+
             # get any sequence IDs first (using @@identity)
             self.cursor.execute("SELECT @@identity AS lastrowid")
             row = self.cursor.fetchone()
@@ -392,15 +392,15 @@ class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext):
 class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext):
     def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
         super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-    
+
     def pre_exec(self):
         super(SybaseSQLExecutionContext_pyodbc, self).pre_exec()
-        
+
     def post_exec(self):
         if self.compiled.isinsert:
             table = self.compiled.statement.table
             # get the inserted values of the primary key
-            
+
             # get any sequence IDs first (using @@identity)
             self.cursor.execute("SELECT @@identity AS lastrowid")
             row = self.cursor.fetchone()
@@ -474,13 +474,13 @@ class SybaseSQLDialect(default.DefaultDialect):
             return dialect(*args, **kwargs)
         else:
             return object.__new__(cls, *args, **kwargs)
-                
+
     def __init__(self, **params):
         super(SybaseSQLDialect, self).__init__(**params)
         self.text_as_varchar = False
         # FIXME: what is the default schema for sybase connections (DBA?) ?
         self.set_default_schema_name("dba")
-        
+
     def dbapi(cls, module_name=None):
         if module_name:
             try:
@@ -497,7 +497,7 @@ class SybaseSQLDialect(default.DefaultDialect):
             else:
                 raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc')
     dbapi = classmethod(dbapi)
-    
+
     def create_execution_context(self, *args, **kwargs):
         return SybaseSQLExecutionContext(self, *args, **kwargs)
 
@@ -531,7 +531,7 @@ class SybaseSQLDialect(default.DefaultDialect):
 
     def table_names(self, connection, schema):
         """Ignore the schema and the charset for now."""
-        s = sql.select([tables.c.table_name], 
+        s = sql.select([tables.c.table_name],
                        sql.not_(tables.c.table_name.like("SYS%")) and
                        tables.c.creator >= 100
                        )
@@ -541,7 +541,7 @@ class SybaseSQLDialect(default.DefaultDialect):
     def has_table(self, connection, tablename, schema=None):
         # FIXME: ignore schemas for sybase
         s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
-        
+
         c = connection.execute(s)
         row = c.fetchone()
         print "has_table: " + tablename + ": " + str(bool(row is not None))
@@ -554,7 +554,7 @@ class SybaseSQLDialect(default.DefaultDialect):
         else:
             current_schema = self.get_default_schema_name(connection)
 
-        s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])    
+        s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
 
         c = connection.execute(s)
         found_table = False
@@ -566,7 +566,7 @@ class SybaseSQLDialect(default.DefaultDialect):
             found_table = True
             (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
                 row[columns.c.column_name],
-                row[domains.c.domain_name], 
+                row[domains.c.domain_name],
                 row[columns.c.nulls] == 'Y',
                 row[columns.c.width],
                 row[domains.c.precision],
@@ -630,17 +630,17 @@ class SybaseSQLDialect(default.DefaultDialect):
         for primary_table in foreignKeys.keys():
             #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
             table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1]))
-               
+
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
 
 
-class SybaseSQLDialect_mxodbc(SybaseSQLDialect):    
+class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
     def __init__(self, **params):
         super(SybaseSQLDialect_mxodbc, self).__init__(**params)
 
         self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()}
-        
+
     def import_dbapi(cls):
         #import mx.ODBC.Windows as module
         import mxODBC as module
@@ -653,10 +653,10 @@ class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
     colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc
 
     ischema_names = SybaseSQLDialect.ischema_names.copy()
-    ischema_names['time'] = SybaseTime_mxodbc    
-    ischema_names['date'] = SybaseDate_mxodbc    
-    ischema_names['datetime'] = SybaseDateTime_mxodbc    
-    ischema_names['smalldatetime'] = SybaseDateTime_mxodbc    
+    ischema_names['time'] = SybaseTime_mxodbc
+    ischema_names['date'] = SybaseDate_mxodbc
+    ischema_names['datetime'] = SybaseDateTime_mxodbc
+    ischema_names['smalldatetime'] = SybaseDateTime_mxodbc
 
     def is_disconnect(self, e):
         # FIXME: optimize
@@ -744,7 +744,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
     operators.update({
         sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
     })
-    
+
     def bindparam_string(self, name):
         res = super(SybaseSQLCompiler, self).bindparam_string(name)
         if name.lower().startswith('literal'):
@@ -767,7 +767,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
             s += "START AT %s " % (select._offset+1,)
         return s
 
-    def limit_clause(self, select):    
+    def limit_clause(self, select):
         # Limit in sybase is after the select keyword
         return ""
 
@@ -816,7 +816,7 @@ class SybaseSQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
 
         colspec = self.preparer.format_column(column)
-        
+
         if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
                 column.autoincrement and isinstance(column.type, sqltypes.Integer):
             if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
@@ -827,8 +827,8 @@ class SybaseSQLSchemaGenerator(compiler.SchemaGenerator):
             #colspec += " numeric(30,0) IDENTITY"
             colspec += " Integer IDENTITY"
         else:
-            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=True).get_col_spec()
-            
+            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+
         if not column.nullable:
             colspec += " NOT NULL"
 
index 2de54804ab95f72e5d965fb8b6f6e30dbcf0341b..d2f1e9ad22673151be3961b15eb5300ea3ef58ca 100644 (file)
@@ -26,7 +26,8 @@ import datetime as dt
 import warnings
 
 from sqlalchemy import exceptions
-from sqlalchemy.util import pickle, Decimal as _python_Decimal, warn_deprecated
+from sqlalchemy.util import pickle, Decimal as _python_Decimal
+import sqlalchemy.util as util
 NoneType = type(None)
 
 class _UserTypeAdapter(type):
@@ -393,7 +394,11 @@ class String(Concatenable, TypeEngine):
     def dialect_impl(self, dialect, **kwargs):
         _for_ddl = kwargs.pop('_for_ddl', False)
         if _for_ddl and self.length is None:
-            warn_deprecated("Using String type with no length for CREATE TABLE is deprecated; use the Text type explicitly")
+            label = util.to_ascii(_for_ddl is True and
+                                  '' or (' for column "%s"' % str(_for_ddl)))
+            util.warn_deprecated(
+                "Using String type with no length for CREATE TABLE "
+                "is deprecated; use the Text type explicitly" + label)
         return TypeEngine.dialect_impl(self, dialect, **kwargs)
 
     def get_search_list(self):
index 01cbf5865e4b69aa62ea145193c7d698319ed9b6..5c391ac3dd73cc93640abacf4db1ce2b04d2c388 100644 (file)
@@ -100,6 +100,16 @@ def to_set(x):
     else:
         return x
 
+def to_ascii(x):
+    """Convert Unicode or a string with unknown encoding into ASCII."""
+
+    if isinstance(x, str):
+        return x.encode('string_escape')
+    elif isinstance(x, unicode):
+        return x.encode('unicode_escape')
+    else:
+        raise TypeError
+
 def flatten_iterator(x):
     """Given an iterator of which further sub-elements may also be
     iterators, flatten the sub-elements into a single iterator.
index 98a8ca0e138d55cb9d094eae9a7207bf23c6f46f..c11f8fcd18399af8522025c9219938cc864bcc0f 100644 (file)
@@ -1,6 +1,6 @@
 import testbase
 import pickleable
-import datetime, os
+import datetime, os, re
 from sqlalchemy import *
 from sqlalchemy import types, exceptions
 from sqlalchemy.sql import operators
@@ -688,6 +688,7 @@ class StringTest(AssertMixin):
             assert False
         except SADeprecationWarning, e:
             assert "Using String type with no length" in str(e)
+            assert re.search(r'\bone\b', str(e))
 
         bar = Table('bar', metadata, Column('one', String(40)))