]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- endless isinstance(x, str)s....
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Apr 2013 18:08:28 +0000 (14:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Apr 2013 18:08:28 +0000 (14:08 -0400)
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/plugin/noseplugin.py
lib/sqlalchemy/testing/warnings.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/langhelpers.py

index b53eb88872aa0e2f161cc6a1c17efa87077f8dfc..a8705beb4367f9076593ab8b1aa7c74ea1fe6550 100644 (file)
@@ -679,11 +679,11 @@ class Table(SchemaItem, expression.TableClause):
             # skip indexes that would be generated
             # by the 'index' flag on Column
             if len(index.columns) == 1 and \
-                list(index.columns)[0].index:
+                    list(index.columns)[0].index:
                 continue
             Index(index.name,
                   unique=index.unique,
-                  *[table.c[col] for col in list(index.columns.keys())],
+                  *[table.c[col] for col in index.columns.keys()],
                   **index.kwargs)
         table.dispatch._update(self.dispatch)
         return table
@@ -898,7 +898,7 @@ class Column(SchemaItem, expression.ColumnClause):
         type_ = kwargs.pop('type_', None)
         args = list(args)
         if args:
-            if isinstance(args[0], str):
+            if isinstance(args[0], util.string_types):
                 if name is not None:
                     raise exc.ArgumentError(
                         "May not pass name positionally and as a keyword.")
@@ -944,12 +944,7 @@ class Column(SchemaItem, expression.ColumnClause):
                 args.append(self.default)
             else:
                 if getattr(self.type, '_warn_on_bytestring', False):
-# start Py3K
-                    if isinstance(self.default, bytes):
-# end Py3K
-# start Py2K
-#                    if isinstance(self.default, str):
-# end Py2K
+                    if isinstance(self.default, util.binary_type):
                         util.warn("Unicode column received non-unicode "
                                   "default value.")
                 args.append(ColumnDefault(self.default))
@@ -984,7 +979,7 @@ class Column(SchemaItem, expression.ColumnClause):
 
         if kwargs:
             raise exc.ArgumentError(
-                "Unknown arguments passed to Column: " + repr(list(kwargs.keys())))
+                "Unknown arguments passed to Column: " + repr(list(kwargs)))
 
     def __str__(self):
         if self.name is None:
@@ -1070,7 +1065,7 @@ class Column(SchemaItem, expression.ColumnClause):
         self.table = table
 
         if self.index:
-            if isinstance(self.index, str):
+            if isinstance(self.index, util.string_types):
                 raise exc.ArgumentError(
                     "The 'index' keyword argument on Column is boolean only. "
                     "To create indexes with a specific name, create an "
@@ -1078,7 +1073,7 @@ class Column(SchemaItem, expression.ColumnClause):
             Index(expression._truncated_label('ix_%s' % self._label),
                                     self, unique=self.unique)
         elif self.unique:
-            if isinstance(self.unique, str):
+            if isinstance(self.unique, util.string_types):
                 raise exc.ArgumentError(
                     "The 'unique' keyword argument on Column is boolean "
                     "only. To create unique constraints or indexes with a "
@@ -1338,7 +1333,7 @@ class ForeignKey(SchemaItem):
         if schema:
             return schema + "." + self.column.table.name + \
                                     "." + self.column.key
-        elif isinstance(self._colspec, str):
+        elif isinstance(self._colspec, util.string_types):
             return self._colspec
         elif hasattr(self._colspec, '__clause_element__'):
             _column = self._colspec.__clause_element__()
@@ -1383,7 +1378,7 @@ class ForeignKey(SchemaItem):
         """
         # ForeignKey inits its remote column as late as possible, so tables
         # can be defined without dependencies
-        if isinstance(self._colspec, str):
+        if isinstance(self._colspec, util.string_types):
             # locate the parent table this foreign key is attached to.  we
             # use the "original" column which our parent column represents
             # (its a list of columns/other ColumnElements if the parent
@@ -1650,8 +1645,7 @@ class ColumnDefault(DefaultGenerator):
         defaulted = argspec[3] is not None and len(argspec[3]) or 0
         positionals = len(argspec[0]) - defaulted
 
-# start Py3K
-# end Py3K
+        # Py3K compat - no unbound methods
         if inspect.ismethod(inspectable) or inspect.isclass(fn):
             positionals -= 1
 
@@ -1913,7 +1907,7 @@ class DefaultClause(FetchedValue):
     has_argument = True
 
     def __init__(self, arg, for_update=False, _reflected=False):
-        util.assert_arg_type(arg, (str,
+        util.assert_arg_type(arg, (util.string_types[0],
                                    expression.ClauseElement,
                                    expression.TextClause), 'arg')
         super(DefaultClause, self).__init__(for_update)
@@ -2023,7 +2017,7 @@ class ColumnCollectionMixin(object):
 
     def _set_parent(self, table):
         for col in self._pending_colargs:
-            if isinstance(col, str):
+            if isinstance(col, util.string_types):
                 col = table.c[col]
             self.columns.add(col)
 
@@ -2060,7 +2054,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
 
     def copy(self, **kw):
         c = self.__class__(name=self.name, deferrable=self.deferrable,
-                              initially=self.initially, *list(self.columns.keys()))
+                              initially=self.initially, *self.columns.keys())
         c.dispatch._update(self.dispatch)
         return c
 
@@ -2241,7 +2235,7 @@ class ForeignKeyConstraint(Constraint):
             self._set_parent_with_dispatch(table)
         elif columns and \
             isinstance(columns[0], Column) and \
-            columns[0].table is not None:
+                columns[0].table is not None:
             self._set_parent_with_dispatch(columns[0].table)
 
     @property
@@ -2250,7 +2244,7 @@ class ForeignKeyConstraint(Constraint):
 
     @property
     def columns(self):
-        return list(self._elements.keys())
+        return list(self._elements)
 
     @property
     def elements(self):
@@ -2262,7 +2256,7 @@ class ForeignKeyConstraint(Constraint):
         for col, fk in self._elements.items():
             # string-specified column names now get
             # resolved to Column objects
-            if isinstance(col, str):
+            if isinstance(col, util.string_types):
                 try:
                     col = table.c[col]
                 except KeyError:
@@ -2272,7 +2266,7 @@ class ForeignKeyConstraint(Constraint):
                         "named '%s' is present." % (table.description, col))
 
             if not hasattr(fk, 'parent') or \
-                fk.parent is not col:
+                    fk.parent is not col:
                 fk._set_parent_with_dispatch(col)
 
         if self.use_alter:
@@ -2287,8 +2281,8 @@ class ForeignKeyConstraint(Constraint):
 
     def copy(self, schema=None, **kw):
         fkc = ForeignKeyConstraint(
-                    [x.parent.key for x in list(self._elements.values())],
-                    [x._get_colspec(schema=schema) for x in list(self._elements.values())],
+                    [x.parent.key for x in self._elements.values()],
+                    [x._get_colspec(schema=schema) for x in self._elements.values()],
                     name=self.name,
                     onupdate=self.onupdate,
                     ondelete=self.ondelete,
@@ -2563,7 +2557,7 @@ class MetaData(SchemaItem):
         return 'MetaData(bind=%r)' % self.bind
 
     def __contains__(self, table_or_key):
-        if not isinstance(table_or_key, str):
+        if not isinstance(table_or_key, util.string_types):
             table_or_key = table_or_key.key
         return table_or_key in self.tables
 
@@ -2578,7 +2572,7 @@ class MetaData(SchemaItem):
         dict.pop(self.tables, key, None)
         if self._schemas:
             self._schemas = set([t.schema
-                                for t in list(self.tables.values())
+                                for t in self.tables.values()
                                 if t.schema is not None])
 
     def __getstate__(self):
@@ -2623,7 +2617,7 @@ class MetaData(SchemaItem):
     def _bind_to(self, bind):
         """Bind this MetaData to an Engine, Connection, string or URL."""
 
-        if isinstance(bind, (str, url.URL)):
+        if isinstance(bind, util.string_types + (url.URL, )):
             from sqlalchemy import create_engine
             self._bind = create_engine(bind)
         else:
@@ -2656,7 +2650,7 @@ class MetaData(SchemaItem):
             :meth:`.Inspector.sorted_tables`
 
         """
-        return sqlutil.sort_tables(iter(self.tables.values()))
+        return sqlutil.sort_tables(self.tables.values())
 
     def reflect(self, bind=None, schema=None, views=False, only=None):
         """Load all available table definitions from the database.
@@ -2717,7 +2711,7 @@ class MetaData(SchemaItem):
                     bind.dialect.get_view_names(conn, schema)
                 )
 
-            current = set(self.tables.keys())
+            current = set(self.tables)
 
             if only is None:
                 load = [name for name in available if name not in current]
@@ -2839,7 +2833,7 @@ class ThreadLocalMetaData(MetaData):
     def _bind_to(self, bind):
         """Bind to a Connectable in the caller's thread."""
 
-        if isinstance(bind, (str, url.URL)):
+        if isinstance(bind, util.string_types + (url.URL, )):
             try:
                 self.context._engine = self.__engines[bind]
             except KeyError:
@@ -3069,7 +3063,7 @@ class DDLElement(expression.Executable, _DDLCompiles):
             not self._should_execute_deprecated(None, target, bind, **kw):
             return False
 
-        if isinstance(self.dialect, str):
+        if isinstance(self.dialect, util.string_types):
             if self.dialect != bind.engine.name:
                 return False
         elif isinstance(self.dialect, (tuple, list, set)):
@@ -3084,7 +3078,7 @@ class DDLElement(expression.Executable, _DDLCompiles):
     def _should_execute_deprecated(self, event, target, bind, **kw):
         if self.on is None:
             return True
-        elif isinstance(self.on, str):
+        elif isinstance(self.on, util.string_types):
             return self.on == bind.engine.name
         elif isinstance(self.on, (tuple, list, set)):
             return bind.engine.name in self.on
@@ -3099,7 +3093,7 @@ class DDLElement(expression.Executable, _DDLCompiles):
 
     def _check_ddl_on(self, on):
         if (on is not None and
-            (not isinstance(on, (str, tuple, list, set)) and
+            (not isinstance(on, util.string_types + (tuple, list, set)) and
                     not util.callable(on))):
             raise exc.ArgumentError(
                 "Expected the name of a database dialect, a tuple "
@@ -3224,7 +3218,7 @@ class DDL(DDLElement):
 
         """
 
-        if not isinstance(statement, str):
+        if not isinstance(statement, util.string_types):
             raise exc.ArgumentError(
                 "Expected a string or unicode SQL statement, got '%r'" %
                 statement)
@@ -3256,7 +3250,7 @@ def _to_schema_column(element):
 def _to_schema_column_or_string(element):
     if hasattr(element, '__clause_element__'):
         element = element.__clause_element__()
-    if not isinstance(element, (str, expression.ColumnElement)):
+    if not isinstance(element, util.string_types + (expression.ColumnElement, )):
         msg = "Element %r is not a string name or column element"
         raise exc.ArgumentError(msg % element)
     return element
index b3f74ceefc0cf31082dac4a67815763ec5f1c073..d51dd625a448bc74651d56e8fe96f225910d95e0 100644 (file)
@@ -83,9 +83,7 @@ OPERATORS = {
     operators.add: ' + ',
     operators.mul: ' * ',
     operators.sub: ' - ',
-# start Py2K
-#    operators.div: ' / ',
-# end Py2K
+    operators.div: ' / ',
     operators.mod: ' % ',
     operators.truediv: ' / ',
     operators.neg: '-',
@@ -826,12 +824,12 @@ class SQLCompiler(engine.Compiled):
         of the DBAPI.
 
         """
-        if isinstance(value, str):
+        if isinstance(value, util.string_types):
             value = value.replace("'", "''")
             return "'%s'" % value
         elif value is None:
             return "NULL"
-        elif isinstance(value, (float, int)):
+        elif isinstance(value, (float, ) + util.int_types):
             return repr(value)
         elif isinstance(value, decimal.Decimal):
             return str(value)
@@ -1214,7 +1212,7 @@ class SQLCompiler(engine.Compiled):
             self.positiontup = self.cte_positional + self.positiontup
         cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
         cte_text += ", \n".join(
-            [txt for txt in list(self.ctes.values())]
+            [txt for txt in self.ctes.values()]
         )
         cte_text += "\n "
         return cte_text
@@ -1325,7 +1323,7 @@ class SQLCompiler(engine.Compiled):
             dialect_hints = dict([
                 (table, hint_text)
                 for (table, dialect), hint_text in
-                list(insert_stmt._hints.items())
+                insert_stmt._hints.items()
                 if dialect in ('*', self.dialect.name)
             ])
             if insert_stmt.table in dialect_hints:
@@ -1422,7 +1420,7 @@ class SQLCompiler(engine.Compiled):
             dialect_hints = dict([
                 (table, hint_text)
                 for (table, dialect), hint_text in
-                list(update_stmt._hints.items())
+                update_stmt._hints.items()
                 if dialect in ('*', self.dialect.name)
             ])
             if update_stmt.table in dialect_hints:
@@ -1559,7 +1557,7 @@ class SQLCompiler(engine.Compiled):
         if extra_tables and stmt_parameters:
             normalized_params = dict(
                 (sql._clause_element_as_expr(c), param)
-                for c, param in list(stmt_parameters.items())
+                for c, param in stmt_parameters.items()
             )
             assert self.isupdate
             affected_tables = set()
@@ -1752,7 +1750,7 @@ class SQLCompiler(engine.Compiled):
             dialect_hints = dict([
                 (table, hint_text)
                 for (table, dialect), hint_text in
-                list(delete_stmt._hints.items())
+                delete_stmt._hints.items()
                 if dialect in ('*', self.dialect.name)
             ])
             if delete_stmt.table in dialect_hints:
@@ -1870,11 +1868,11 @@ class DDLCompiler(engine.Compiled):
                     first_pk = True
             except exc.CompileError as ce:
                 util.raise_from_cause(
-                    exc.CompileError("(in table '%s', column '%s'): %s" % (
+                    exc.CompileError(util.u("(in table '%s', column '%s'): %s" % (
                                                 table.description,
                                                 column.name,
                                                 ce.args[0]
-                                            )))
+                                            ))))
 
         const = self.create_table_constraints(table)
         if const:
index 1ad6364d2f2032a891fd622b19dbd8c23291da8f..aff5512d313e7b614b4c41d573b4c152afe39c8d 100644 (file)
@@ -26,7 +26,7 @@ to stay the same in future releases.
 
 """
 
-
+from __future__ import unicode_literals
 import itertools
 import re
 from operator import attrgetter
@@ -1375,7 +1375,7 @@ func = _FunctionGenerator()
 modifier = _FunctionGenerator(group=False)
 
 
-class _truncated_label(str):
+class _truncated_label(util.text_type):
     """A unicode subclass used to identify symbolic "
     "names that may require truncation."""
 
@@ -1395,13 +1395,13 @@ class _anonymous_label(_truncated_label):
 
     def __add__(self, other):
         return _anonymous_label(
-                    str(self) +
-                    str(other))
+                    util.text_type(self) +
+                    util.text_type(other))
 
     def __radd__(self, other):
         return _anonymous_label(
-                    str(other) +
-                    str(self))
+                    util.text_type(other) +
+                    util.text_type(self))
 
     def apply_map(self, map_):
         return self % map_
@@ -1422,7 +1422,7 @@ def _as_truncated(value):
 
 
 def _string_or_unprintable(element):
-    if isinstance(element, str):
+    if isinstance(element, util.string_types):
         return element
     else:
         try:
@@ -1486,7 +1486,7 @@ def _labeled(element):
 
 
 def _column_as_key(element):
-    if isinstance(element, str):
+    if isinstance(element, util.string_types):
         return element
     if hasattr(element, '__clause_element__'):
         element = element.__clause_element__()
@@ -1508,8 +1508,8 @@ def _literal_as_text(element):
         return element
     elif hasattr(element, '__clause_element__'):
         return element.__clause_element__()
-    elif isinstance(element, str):
-        return TextClause(str(element))
+    elif isinstance(element, util.string_types):
+        return TextClause(util.text_type(element))
     elif isinstance(element, (util.NoneType, bool)):
         return _const_expr(element)
     else:
@@ -1583,8 +1583,8 @@ def _interpret_as_column_or_from(element):
 def _interpret_as_from(element):
     insp = inspection.inspect(element, raiseerr=False)
     if insp is None:
-        if isinstance(element, str):
-            return TextClause(str(element))
+        if isinstance(element, util.string_types):
+            return TextClause(util.text_type(element))
     elif hasattr(insp, "selectable"):
         return insp.selectable
     raise exc.ArgumentError("FROM expression expected")
@@ -1914,12 +1914,10 @@ class ClauseElement(Visitable):
         return dialect.statement_compiler(dialect, self, **kw)
 
     def __str__(self):
-# start Py3K
-        return str(self.compile())
-# end Py3K
-# start Py2K
-#        return unicode(self.compile()).encode('ascii', 'backslashreplace')
-# end Py2K
+        if util.py3k:
+            return str(self.compile())
+        else:
+            return unicode(self.compile()).encode('ascii', 'backslashreplace')
 
     def __and__(self, other):
         return and_(self, other)
@@ -1933,6 +1931,8 @@ class ClauseElement(Visitable):
     def __bool__(self):
         raise TypeError("Boolean value of this clause is not defined")
 
+    __nonzero__ = __bool__
+
     def _negate(self):
         if hasattr(self, 'negation_clause'):
             return self.negation_clause
@@ -2508,7 +2508,7 @@ class ColumnCollection(util.OrderedProperties):
     def update(self, value):
         self._data.update(value)
         self._all_cols.clear()
-        self._all_cols.update(list(self._data.values()))
+        self._all_cols.update(self._data.values())
 
     def extend(self, iter):
         self.update((c.key, c) for c in iter)
@@ -2524,13 +2524,13 @@ class ColumnCollection(util.OrderedProperties):
         return and_(*l)
 
     def __contains__(self, other):
-        if not isinstance(other, str):
+        if not isinstance(other, util.string_types):
             raise exc.ArgumentError("__contains__ requires a string argument")
         return util.OrderedProperties.__contains__(self, other)
 
     def __setstate__(self, state):
         self.__dict__['_data'] = state['_data']
-        self.__dict__['_all_cols'] = util.column_set(list(self._data.values()))
+        self.__dict__['_all_cols'] = util.column_set(self._data.values())
 
     def contains_column(self, col):
         # this has to be done via set() membership
@@ -3185,13 +3185,13 @@ class TextClause(Executable, ClauseElement):
     _hide_froms = []
 
     def __init__(
-        self,
-        text='',
-        bind=None,
-        bindparams=None,
-        typemap=None,
-        autocommit=None,
-        ):
+                    self,
+                    text='',
+                    bind=None,
+                    bindparams=None,
+                    typemap=None,
+                    autocommit=None):
+
         self._bind = bind
         self.bindparams = {}
         self.typemap = typemap
@@ -3201,9 +3201,9 @@ class TextClause(Executable, ClauseElement):
                                  'e)')
             self._execution_options = \
                 self._execution_options.union(
-                  {'autocommit': autocommit})
+                    {'autocommit': autocommit})
         if typemap is not None:
-            for key in list(typemap.keys()):
+            for key in typemap:
                 typemap[key] = sqltypes.to_instance(typemap[key])
 
         def repl(m):
@@ -3237,7 +3237,7 @@ class TextClause(Executable, ClauseElement):
 
     def _copy_internals(self, clone=_clone, **kw):
         self.bindparams = dict((b.key, clone(b, **kw))
-                               for b in list(self.bindparams.values()))
+                               for b in self.bindparams.values())
 
     def get_children(self, **kwargs):
         return list(self.bindparams.values())
@@ -3751,7 +3751,7 @@ class BinaryExpression(ColumnElement):
                     negate=None, modifiers=None):
         # allow compatibility with libraries that
         # refer to BinaryExpression directly and pass strings
-        if isinstance(operator, str):
+        if isinstance(operator, util.string_types):
             operator = operators.custom_op(operator)
         self._orig = (left, right)
         self.left = _literal_as_text(left).self_group(against=operator)
@@ -3770,6 +3770,7 @@ class BinaryExpression(ColumnElement):
             return self.operator(hash(self._orig[0]), hash(self._orig[1]))
         else:
             raise TypeError("Boolean value of this clause is not defined")
+    __nonzero__ = __bool__
 
     @property
     def is_comparison(self):
@@ -4058,12 +4059,10 @@ class Alias(FromClause):
 
     @property
     def description(self):
-# start Py3K
-        return self.name
-# end Py3K
-# start Py2K
-#        return self.name.encode('ascii', 'backslashreplace')
-# end Py2K
+        if util.py3k:
+            return self.name
+        else:
+            return self.name.encode('ascii', 'backslashreplace')
 
     def as_scalar(self):
         try:
@@ -4473,12 +4472,10 @@ class ColumnClause(Immutable, ColumnElement):
 
     @util.memoized_property
     def description(self):
-# start Py3K
-        return self.name
-# end Py3K
-# start Py2K
-#        return self.name.encode('ascii', 'backslashreplace')
-# end Py2K
+        if util.py3k:
+            return self.name
+        else:
+            return self.name.encode('ascii', 'backslashreplace')
 
     @_memoized_property
     def _key_label(self):
@@ -4605,12 +4602,10 @@ class TableClause(Immutable, FromClause):
 
     @util.memoized_property
     def description(self):
-# start Py3K
-        return self.name
-# end Py3K
-# start Py2K
-#        return self.name.encode('ascii', 'backslashreplace')
-# end Py2K
+        if util.py3k:
+            return self.name
+        else:
+            return self.name.encode('ascii', 'backslashreplace')
 
     def append_column(self, c):
         self._columns[c.key] = c
index cf1c484d0a3500ece86f3df51a606f34aecdba1c..4afb3db48b0cb4048a382e6184d64a24e772c6af 100644 (file)
@@ -9,16 +9,19 @@
 
 """Defines operators used in SQL expressions."""
 
+from .. import util
+
+
 from operator import (
     and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg,
     getitem, lshift, rshift
     )
 
-# start Py2K
-#from operator import (div,)
-# end Py2K
+if util.py2k:
+    from operator import div
+else:
+    div = truediv
 
-from ..util import symbol
 
 
 class Operators(object):
@@ -781,17 +784,15 @@ parenthesize (a op b).
 
 """
 
-_smallest = symbol('_smallest', canonical=-100)
-_largest = symbol('_largest', canonical=100)
+_smallest = util.symbol('_smallest', canonical=-100)
+_largest = util.symbol('_largest', canonical=100)
 
 _PRECEDENCE = {
     from_: 15,
     getitem: 15,
     mul: 8,
     truediv: 8,
-# start Py2K
-#    div: 8,
-# end Py2K
+    div: 8,
     mod: 8,
     neg: 8,
     add: 7,
index e01948f9ccf993011c18468bcb23d0c26c9e6243..c041539619dcbdfdbb372d59485379612b5e8b04 100644 (file)
@@ -1,4 +1,4 @@
-
+from __future__ import absolute_import
 
 from . import util as testutil
 from sqlalchemy import pool, orm, util
@@ -63,7 +63,7 @@ def emits_warning_on(db, *warnings):
 
     @decorator
     def decorate(fn, *args, **kw):
-        if isinstance(db, str):
+        if isinstance(db, util.string_types):
             if not spec(config.db):
                 return fn(*args, **kw)
             else:
@@ -172,8 +172,8 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
         callable_(*args, **kwargs)
         assert False, "Callable did not raise an exception"
     except except_cls as e:
-        assert re.search(msg, str(e), re.UNICODE), "%r !~ %s" % (msg, e)
-        print(str(e).encode('utf-8'))
+        assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
+        print(util.text_type(e).encode('utf-8'))
 
 
 class AssertsCompiledSQL(object):
@@ -190,12 +190,12 @@ class AssertsCompiledSQL(object):
                 dialect = default.DefaultDialect()
             elif dialect is None:
                 dialect = config.db.dialect
-            elif isinstance(dialect, str):
+            elif isinstance(dialect, util.string_types):
                 dialect = create_engine("%s://" % dialect).dialect
 
         kw = {}
         if params is not None:
-            kw['column_keys'] = list(params.keys())
+            kw['column_keys'] = list(params)
 
         if isinstance(clause, orm.Query):
             context = clause._compile_context()
@@ -205,13 +205,13 @@ class AssertsCompiledSQL(object):
         c = clause.compile(dialect=dialect, **kw)
 
         param_str = repr(getattr(c, 'params', {}))
-# start Py3K
-        param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
-# end Py3K
 
-        print("\nSQL String:\n" + str(c) + param_str)
+        if util.py3k:
+            param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+
+        print("\nSQL String:\n" + util.text_type(c) + param_str)
 
-        cc = re.sub(r'[\n\t]', '', str(c))
+        cc = re.sub(r'[\n\t]', '', util.text_type(c))
 
         eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
 
@@ -301,7 +301,7 @@ class AssertsExecutionResults(object):
         found = util.IdentitySet(result)
         expected = set([immutabledict(e) for e in expected])
 
-        for wrong in itertools.filterfalse(lambda o: type(o) == cls, found):
+        for wrong in util.itertools_filterfalse(lambda o: type(o) == cls, found):
             fail('Unexpected type "%s", expected "%s"' % (
                 type(wrong).__name__, cls.__name__))
 
index 08b2361f22a54dec7d2552a62b0975b9ba59d7ae..daa779ae3181f876b74f7d2d886f38ad25c2e9f9 100644 (file)
@@ -1,6 +1,7 @@
 from . import config
 from . import assertions, schema
 from .util import adict
+from .. import util
 from .engines import drop_all_tables
 from .entities import BasicEntity, ComparableEntity
 import sys
@@ -126,8 +127,9 @@ class TablesTest(TestBase):
                 try:
                     table.delete().execute().close()
                 except sa.exc.DBAPIError as ex:
-                    print("Error emptying table %s: %r" % (
-                        table, ex), file=sys.stderr)
+                    util.print_(
+                        ("Error emptying table %s: %r" % (table, ex)),
+                        file=sys.stderr)
 
     def setup(self):
         self._setup_each_tables()
@@ -190,7 +192,7 @@ class TablesTest(TestBase):
         for table, data in cls.fixtures().items():
             if len(data) < 2:
                 continue
-            if isinstance(table, str):
+            if isinstance(table, util.string_types):
                 table = cls.tables[table]
             headers[table] = data[0]
             rows[table] = data[1:]
@@ -199,7 +201,7 @@ class TablesTest(TestBase):
                 continue
             cls.bind.execute(
                 table.insert(),
-                [dict(list(zip(headers[table], column_values)))
+                [dict(zip(headers[table], column_values))
                  for column_values in rows[table]])
 
 
@@ -284,7 +286,7 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
                 cls_registry[classname] = cls
                 return type.__init__(cls, classname, bases, dict_)
 
-        class _Base(object, metaclass=FindFixture):
+        class _Base(util.with_metaclass(FindFixture, object)):
             pass
 
         class Basic(BasicEntity, _Base):
index 7ad61c7b9cd3cb0eb6f27ffc2129f57333727736..b3cd3a4e3eda953be64b8a1382d610d06c804866 100644 (file)
@@ -10,13 +10,19 @@ normally as "from sqlalchemy.testing.plugin import noseplugin".
 
 """
 
+from __future__ import absolute_import
 
 import os
-import configparser
+import sys
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+    import configparser
+else:
+    import ConfigParser as configparser
 
 from nose.plugins import Plugin
 from nose import SkipTest
-import time
 import sys
 import re
 
index 9546945ebf10aff7ef42d6d8927dc5f64979677a..6193acd886e9e574e0d80482a11ef1f43ad165b5 100644 (file)
@@ -1,4 +1,4 @@
-
+from __future__ import absolute_import
 
 import warnings
 from .. import exc as sa_exc
@@ -10,7 +10,7 @@ def testing_warn(msg, stacklevel=3):
 
     filename = "sqlalchemy.testing.warnings"
     lineno = 1
-    if isinstance(msg, str):
+    if isinstance(msg, util.string_types):
         warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
     else:
         warnings.warn_explicit(msg, filename, lineno)
index 9e562402db3e85bb080fff5645597d296c01ff85..25dcce335f1d57b58c9d9d9f6571e2e06c52f64b 100644 (file)
@@ -8,7 +8,7 @@ from .compat import callable, cmp, reduce,  \
     threading, py3k, py2k, jython, pypy, cpython, win32, \
     pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \
     raise_from_cause, text_type, string_types, int_types, binary_type, \
-    quote_plus, with_metaclass
+    quote_plus, with_metaclass, print_, itertools_filterfalse, u, b
 
 from ._collections import KeyedTuple, ImmutableContainer, immutabledict, \
     Properties, OrderedProperties, ImmutableProperties, OrderedDict, \
index bdeb69d1e15424bc2e6251a2e4cc21b21ac83eca..bc7a0fe2136d529da1450aa91bfbb5fbe35261fe 100644 (file)
@@ -33,6 +33,8 @@ else:
         import pickle
 
 if py3k:
+    import builtins
+
     from inspect import getfullargspec as inspect_getfullargspec
     from urllib.parse import quote_plus, unquote_plus, parse_qsl
     string_types = str,
@@ -41,6 +43,9 @@ if py3k:
     int_types = int,
     iterbytes = iter
 
+    def u(s):
+        return s
+
     def b(s):
         return s.encode("latin-1")
 
@@ -55,6 +60,13 @@ if py3k:
 
     from functools import reduce
 
+    print_ = getattr(builtins, "print")
+
+    import_ = getattr(builtins, '__import__')
+
+    import itertools
+    itertools_filterfalse = itertools.filterfalse
+    itertools_imap = map
 else:
     from inspect import getargspec as inspect_getfullargspec
     from urllib import quote_plus, unquote_plus
@@ -66,13 +78,35 @@ else:
     def iterbytes(buf):
         return (ord(byte) for byte in buf)
 
+    def u(s):
+        return unicode(s, "unicode_escape")
+
     def b(s):
         return s
 
+    def import_(*args):
+        if len(args) == 4:
+            args = args[0:3] + ([str(arg) for arg in args[3]],)
+        return __import__(*args)
+
     callable = callable
     cmp = cmp
     reduce = reduce
 
+    def print_(*args, **kwargs):
+        fp = kwargs.pop("file", sys.stdout)
+        if fp is None:
+            return
+        for arg in enumerate(args):
+            if not isinstance(arg, basestring):
+                arg = str(arg)
+            fp.write(arg)
+
+    import itertools
+    itertools_filterfalse = itertools.ifilterfalse
+    itertools_imap = itertools.imap
+
+
 
 try:
     from weakref import WeakSet
@@ -131,6 +165,12 @@ else:
         exc_type, exc_value, exc_tb = exc_info
         reraise(type(exception), exception, tb=exc_tb)
 
+if py3k:
+    exec_ = getattr(builtins, 'exec')
+else:
+    def exec_(func_text, globals_, lcl):
+        exec('exec func_text in globals_, lcl')
+
 
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass."""
index f65803dc63a643d0f973a2d086f4430c2db8db4d..8a6af3758ed8746a34b4a80361663007c9fedbc2 100644 (file)
@@ -15,18 +15,14 @@ import re
 import sys
 import types
 import warnings
-from .compat import threading, \
-    callable, inspect_getfullargspec, py3k
 from functools import update_wrapper
 from .. import exc
 import hashlib
 from . import compat
-import collections
 
 def md5_hex(x):
-# start Py3K
-    x = x.encode('utf-8')
-# end Py3K
+    if compat.py3k:
+        x = x.encode('utf-8')
     m = hashlib.md5()
     m.update(x)
     return m.hexdigest()
@@ -79,7 +75,7 @@ def _unique_symbols(used, *bases):
     used = set(used)
     for base in bases:
         pool = itertools.chain((base,),
-                               map(lambda i: base + str(i),
+                               compat.itertools_imap(lambda i: base + str(i),
                                               range(1000)))
         for sym in pool:
             if sym not in used:
@@ -96,7 +92,7 @@ def decorator(target):
     def decorate(fn):
         if not inspect.isfunction(fn):
             raise Exception("not a decoratable function")
-        spec = inspect_getfullargspec(fn)
+        spec = compat.inspect_getfullargspec(fn)
         names = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
         targ_name, fn_name = _unique_symbols(names, 'target', 'fn')
 
@@ -145,7 +141,7 @@ class PluginLoader(object):
 
     def register(self, name, modulepath, objname):
         def load():
-            mod = __import__(modulepath)
+            mod = compat.import_(modulepath)
             for token in modulepath.split(".")[1:]:
                 mod = getattr(mod, token)
             return getattr(mod, objname)
@@ -252,8 +248,8 @@ def format_argspec_plus(fn, grouped=True):
        'apply_pos': '(self, a, b, c, **d)'}
 
     """
-    if isinstance(fn, collections.Callable):
-        spec = inspect_getfullargspec(fn)
+    if compat.callable(fn):
+        spec = compat.inspect_getfullargspec(fn)
     else:
         # we accept an existing argspec...
         spec = fn
@@ -265,7 +261,7 @@ def format_argspec_plus(fn, grouped=True):
     else:
         self_arg = None
 
-    if py3k:
+    if compat.py3k:
         apply_pos = inspect.formatargspec(spec[0], spec[1],
             spec[2], None, spec[4])
         num_defaults = 0
@@ -420,34 +416,33 @@ def class_hierarchy(cls):
     will not be descended.
 
     """
-# start Py2K
-#    if isinstance(cls, types.ClassType):
-#        return list()
-# end Py2K
+    if compat.py2k:
+        if isinstance(cls, types.ClassType):
+            return list()
+
     hier = set([cls])
     process = list(cls.__mro__)
     while process:
         c = process.pop()
-# start Py2K
-#        if isinstance(c, types.ClassType):
-#            continue
-#        for b in (_ for _ in c.__bases__
-#                  if _ not in hier and not isinstance(_, types.ClassType)):
-# end Py2K
-# start Py3K
-        for b in (_ for _ in c.__bases__
-                  if _ not in hier):
-# end Py3K
+        if compat.py2k:
+            if isinstance(c, types.ClassType):
+                continue
+            bases = (_ for _ in c.__bases__
+                  if _ not in hier and not isinstance(_, types.ClassType))
+        else:
+            bases = (_ for _ in c.__bases__ if _ not in hier)
+
+        for b in bases:
             process.append(b)
             hier.add(b)
-# start Py3K
-        if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'):
-            continue
-# end Py3K
-# start Py2K
-#        if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'):
-#            continue
-# end Py2K
+
+        if compat.py3k:
+            if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'):
+                continue
+        else:
+            if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'):
+                continue
+
         for s in [_ for _ in c.__subclasses__() if _ not in hier]:
             process.append(s)
             hier.add(s)
@@ -504,7 +499,7 @@ def monkeypatch_proxied_specials(into_cls, from_cls, skip=None, only=None,
               "return %(name)s.%(method)s%(d_args)s" % locals())
 
         env = from_instance is not None and {name: from_instance} or {}
-        exec(py, env)
+        compat.exec_(py, env, {})
         try:
             env[method].__defaults__ = fn.__defaults__
         except AttributeError:
@@ -593,7 +588,7 @@ def as_interface(obj, cls=None, methods=None, required=None):
     for method, impl in dictlike_iteritems(obj):
         if method not in interface:
             raise TypeError("%r: unknown in this interface" % method)
-        if not isinstance(impl, collections.Callable):
+        if not compat.callable(impl):
             raise TypeError("%r=%r is not callable" % (method, impl))
         setattr(AnonymousInterface, method, staticmethod(impl))
         found.add(method)
@@ -734,11 +729,11 @@ class importlater(object):
     def _resolve(self):
         importlater._unresolved.discard(self)
         if self._il_addtl:
-            self._initial_import = __import__(
+            self._initial_import = compat.import_(
                                 self._il_path, globals(), locals(),
                                 [self._il_addtl])
         else:
-            self._initial_import = __import__(self._il_path)
+            self._initial_import = compat.import_(self._il_path)
 
     def __getattr__(self, key):
         if key == 'module':
@@ -757,7 +752,7 @@ class importlater(object):
 
 # from paste.deploy.converters
 def asbool(obj):
-    if isinstance(obj, str):
+    if isinstance(obj, compat.string_types):
         obj = obj.strip().lower()
         if obj in ['true', 'yes', 'on', 'y', 't', '1']:
             return True
@@ -817,7 +812,7 @@ def constructor_copy(obj, cls, **kw):
 def counter():
     """Return a threadsafe counter function."""
 
-    lock = threading.Lock()
+    lock = compat.threading.Lock()
     counter = itertools.count(1)
 
     # avoid the 2to3 "next" transformation...
@@ -880,16 +875,14 @@ def assert_arg_type(arg, argtype, name):
 def dictlike_iteritems(dictlike):
     """Return a (key, value) iterator for almost any dict-like object."""
 
-# start Py3K
-    if hasattr(dictlike, 'items'):
-        return list(dictlike.items())
-# end Py3K
-# start Py2K
-#    if hasattr(dictlike, 'iteritems'):
-#        return dictlike.iteritems()
-#    elif hasattr(dictlike, 'items'):
-#        return iter(dictlike.items())
-# end Py2K
+    if compat.py3k:
+        if hasattr(dictlike, 'items'):
+            return list(dictlike.items())
+    else:
+        if hasattr(dictlike, 'iteritems'):
+            return dictlike.iteritems()
+        elif hasattr(dictlike, 'items'):
+            return iter(dictlike.items())
 
     getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None))
     if getter is None:
@@ -902,7 +895,7 @@ def dictlike_iteritems(dictlike):
                 yield key, getter(key)
         return iterator()
     elif hasattr(dictlike, 'keys'):
-        return iter((key, getter(key)) for key in list(dictlike.keys()))
+        return iter((key, getter(key)) for key in dictlike.keys())
     else:
         raise TypeError(
             "Object '%r' is not dict-like" % dictlike)
@@ -942,7 +935,7 @@ class hybridmethod(object):
 class _symbol(int):
     def __new__(self, name, doc=None, canonical=None):
         """Construct a new named symbol."""
-        assert isinstance(name, str)
+        assert isinstance(name, compat.string_types)
         if canonical is None:
             canonical = hash(name)
         v = int.__new__(_symbol, canonical)
@@ -985,7 +978,7 @@ class symbol(object):
 
     """
     symbols = {}
-    _lock = threading.Lock()
+    _lock = compat.threading.Lock()
 
     def __new__(cls, name, doc=None, canonical=None):
         cls._lock.acquire()
@@ -1039,7 +1032,7 @@ def warn(msg, stacklevel=3):
        be controlled.
 
     """
-    if isinstance(msg, str):
+    if isinstance(msg, compat.string_types):
         warnings.warn(msg, exc.SAWarning, stacklevel=stacklevel)
     else:
         warnings.warn(msg, stacklevel=stacklevel)