]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
adjustments to improve readability (indentation to complex conditional expressions...
authorBrad Allen <bradallen137@gmail.com>
Thu, 18 Mar 2010 18:16:01 +0000 (12:16 -0600)
committerBrad Allen <bradallen137@gmail.com>
Thu, 18 Mar 2010 18:16:01 +0000 (12:16 -0600)
lib/sqlalchemy/dialects/mssql/base.py

index e6e9a0e41b82609811696c4c5c9f85df46ab0003..6a35cbc87671caad6e47c7ce31987ffddfa4245c 100644 (file)
@@ -881,18 +881,20 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def visit_select(self, select, **kwargs):
         """Look for ``LIMIT`` and OFFSET in a select statement, and if
         so tries to wrap it in a subquery with ``row_number()`` criterion.
-
         """
         if not getattr(select, '_mssql_visit', None) and select._offset:
             # to use ROW_NUMBER(), an ORDER BY is required.
             orderby = self.process(select._order_by_clause)
             if not orderby:
-                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
+                raise exc.InvalidRequestError('MSSQL requires an order_by when '
+                                              'using an offset.')
 
             _offset = select._offset
             _limit = select._limit
             select._mssql_visit = True
-            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
+            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" 
+                                                      % orderby).label("mssql_rn")
+                                   ).order_by(None).alias()
 
             limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
             limitselect.append_whereclause("mssql_rn>%d" % _offset)
@@ -932,7 +934,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
         return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
 
     def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+        return ("ROLLBACK TRANSACTION %s" 
+                % self.preparer.format_savepoint(savepoint_stmt))
 
     def visit_column(self, column, result_map=None, **kwargs):
         if column.table is not None and \
@@ -943,27 +946,51 @@ class MSSQLCompiler(compiler.SQLCompiler):
                 converted = expression._corresponding_column_or_error(t, column)
 
                 if result_map is not None:
-                    result_map[column.name.lower()] = (column.name, (column, ), column.type)
+                    result_map[column.name.lower()] = (column.name, (column, ), 
+                                                       column.type)
 
-                return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
+                return super(MSSQLCompiler, self).visit_column(converted, 
+                                                               result_map=None, 
+                                                               **kwargs)
 
-        return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
+        return super(MSSQLCompiler, self).visit_column(column, 
+                                                       result_map=result_map, 
+                                                       **kwargs)
 
     def visit_binary(self, binary, **kwargs):
         """Move bind parameters to the right-hand side of an operator, where
         possible.
 
         """
-        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
-            and not isinstance(binary.right, expression._BindParamClause):
-            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
+        if (
+            isinstance(binary.left, expression._BindParamClause) 
+            and binary.operator == operator.eq
+            and not isinstance(binary.right, expression._BindParamClause)
+            ):
+            return self.process(expression._BinaryExpression(binary.right, 
+                                                             binary.left, 
+                                                             binary.operator), 
+                                **kwargs)
         else:
-            if (binary.operator is operator.eq or binary.operator is operator.ne) and (
-                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
-                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
-                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
+            if (
+
+                (binary.operator is operator.eq or binary.operator is operator.ne) 
+                and (
+                    (isinstance(binary.left, expression._FromGrouping) 
+                     and isinstance(binary.left.element, 
+                                    expression._ScalarSelect)) 
+                    or (isinstance(binary.right, expression._FromGrouping) 
+                        and isinstance(binary.right.element, 
+                                       expression._ScalarSelect)) 
+                    or isinstance(binary.left, expression._ScalarSelect) 
+                    or isinstance(binary.right, expression._ScalarSelect)
+                    )
+
+               ):
                 op = binary.operator == operator.eq and "IN" or "NOT IN"
-                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
+                return self.process(expression._BinaryExpression(binary.left,
+                                                                 binary.right, op),
+                                    **kwargs)
             return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
 
     def returning_clause(self, stmt, returning_cols):
@@ -1044,7 +1071,8 @@ class MSSQLStrictCompiler(MSSQLCompiler):
 
 class MSDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+        colspec = (self.preparer.format_column(column) + " " 
+                   + self.dialect.type_compiler.process(column.type))
 
         if column.nullable is not None:
             if not column.nullable or column.primary_key:
@@ -1053,7 +1081,8 @@ class MSDDLCompiler(compiler.DDLCompiler):
                 colspec += " NULL"
         
         if column.table is None:
-            raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
+            raise exc.InvalidRequestError("mssql requires Table-bound columns " 
+                                          "in order to generate DDL")
             
         seq_col = column.table._autoincrement_column
 
@@ -1075,7 +1104,8 @@ class MSDDLCompiler(compiler.DDLCompiler):
     def visit_drop_index(self, drop):
         return "\nDROP INDEX %s.%s" % (
             self.preparer.quote_identifier(drop.element.table.name),
-            self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote)
+            self.preparer.quote(self._validate_identifier(drop.element.name, False),
+                                drop.element.quote)
             )
 
 
@@ -1083,7 +1113,8 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = RESERVED_WORDS
 
     def __init__(self, dialect):
-        super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+        super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', 
+                                                   final_quote=']')
 
     def _escape_identifier(self, value):
         return value
@@ -1129,7 +1160,8 @@ class MSDialect(default.DefaultDialect):
         super(MSDialect, self).__init__(**opts)
     
     def do_savepoint(self, connection, name):
-        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+        util.warn("Savepoint support in mssql is experimental and "
+                  "may lead to data loss.")
         connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
         connection.execute("SAVE TRANSACTION %s" % name)
 
@@ -1161,7 +1193,8 @@ class MSDialect(default.DefaultDialect):
         return self.schema_name
 
     def table_names(self, connection, schema):
-        s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema)
+        s = select([ischema.tables.c.table_name], 
+                   ischema.tables.c.table_schema==schema)
         return [row[0] for row in connection.execute(s)]
 
 
@@ -1279,7 +1312,8 @@ class MSDialect(default.DefaultDialect):
             coltype = self.ischema_names.get(type, None)
 
             kwargs = {}
-            if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary):
+            if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, 
+                           MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary):
                 kwargs['length'] = charlen
                 if collation:
                     kwargs['collation'] = collation
@@ -1308,7 +1342,9 @@ class MSDialect(default.DefaultDialect):
         for col in cols:
             colmap[col['name']] = col
         # We also run an sp_columns to check for identity columns:
-        cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema))
+        cursor = connection.execute("sp_columns @table_name = '%s', "
+                                    "@table_owner = '%s'" 
+                                    % (tablename, current_schema))
         ic = None
         while True:
             row = cursor.fetchone()