From: Mike Bayer Date: Thu, 18 Dec 2008 17:57:15 +0000 (+0000) Subject: merged -r5299:5438 of py3k warnings branch. this fixes some sqlite py2.6 testing... X-Git-Tag: rel_0_5_0~91 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=be5d3263436b81fb179c8189f1064d477d5fb3e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merged -r5299:5438 of py3k warnings branch. this fixes some sqlite py2.6 testing issues, and also addresses a significant chunk of py3k deprecations. It's mainly expicit __hash__ methods. Additionally, most usage of sets/dicts to store columns uses util-based placeholder names. --- diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 0fc7c8fbd3..9c6c48e0f6 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1034,7 +1034,7 @@ class _BinaryType(sqltypes.Binary): if value is None: return None else: - return buffer(value) + return util.buffer(value) return process class MSVarBinary(_BinaryType): @@ -1081,7 +1081,7 @@ class MSBinary(_BinaryType): if value is None: return None else: - return buffer(value) + return util.buffer(value) return process class MSBlob(_BinaryType): @@ -1108,7 +1108,7 @@ class MSBlob(_BinaryType): if value is None: return None else: - return buffer(value) + return util.buffer(value) return process def __repr__(self): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index c7101d10ec..dbcd5b76be 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1267,7 +1267,6 @@ class Engine(Connectable): return self.pool.unique_connection() - def _proxy_connection_cls(cls, proxy): class ProxyConnection(cls): def execute(self, object, *multiparams, **params): @@ -1319,6 +1318,8 @@ class RowProxy(object): for i in xrange(len(self.__row)): yield self.__parent._get_col(self.__row, i) + __hash__ = None + def __eq__(self, other): return ((other is self) or (other == tuple(self.__parent._get_col(self.__row, key) @@ -1347,18 +1348,23 @@ class RowProxy(object): def items(self): """Return a list of tuples, each tuple containing a key/value pair.""" - return [(key, getattr(self, key)) for key in self.keys()] + return [(key, getattr(self, key)) for key in self.iterkeys()] def keys(self): """Return the list of keys as strings represented by this RowProxy.""" return self.__parent.keys - + + def iterkeys(self): + return iter(self.__parent.keys) + def values(self): """Return the values represented by this RowProxy as a list.""" return list(self) - + + def itervalues(self): + return iter(self) class BufferedColumnRow(RowProxy): def __init__(self, parent, row): @@ -1425,7 +1431,7 @@ class ResultProxy(object): return self._rowcount = None - self._props = util.PopulateDict(None) + self._props = util.populate_column_dict(None) self._props.creator = self.__key_fallback() self.keys = [] @@ -1848,7 +1854,7 @@ class DefaultRunner(schema.SchemaVisitor): def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, expression.ClauseElement): return self.exec_default_sql(onupdate) - elif callable(onupdate.arg): + elif util.callable(onupdate.arg): return onupdate.arg(self.context) else: return onupdate.arg @@ -1856,7 +1862,7 @@ class DefaultRunner(schema.SchemaVisitor): def visit_column_default(self, default): if isinstance(default.arg, expression.ClauseElement): return self.exec_default_sql(default) - elif callable(default.arg): + elif util.callable(default.arg): return default.arg(self.context) else: return default.arg diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 044d701ac6..5c8e68ce45 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -71,6 +71,9 @@ class URL(object): s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) return s + def __hash__(self): + return hash(str(self)) + def __eq__(self, other): return \ isinstance(other, URL) and \ diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 33eb5d240e..315142d8e0 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -487,7 +487,7 @@ class _AssociationList(object): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(list, func_name)): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func @@ -663,7 +663,7 @@ class _AssociationDict(object): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(dict, func_name)): func.__doc__ = getattr(dict, func_name).__doc__ del func_name, func @@ -890,7 +890,7 @@ class _AssociationSet(object): raise TypeError("%s objects are unhashable" % type(self).__name__) for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(set, func_name)): func.__doc__ = getattr(set, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index e59b577e3d..a5d60bf82e 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -65,7 +65,7 @@ ORM-compatible constructor for `OrderingList` instances. """ from sqlalchemy.orm.collections import collection - +from sqlalchemy import util __all__ = [ 'ordering_list' ] @@ -272,7 +272,7 @@ class OrderingList(list): self._reorder() for func_name, func in locals().items(): - if (callable(func) and func.func_name == func_name and + if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(list, func_name)): func.__doc__ = getattr(list, func_name).__doc__ del func_name, func diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 79be76c3ad..f113a4eb95 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -175,7 +175,7 @@ def proxied_attribute_factory(descriptor): @property def comparator(self): - if callable(self._comparator): + if util.callable(self._comparator): self._comparator = self._comparator() return self._comparator @@ -838,7 +838,7 @@ class InstanceState(object): @property def sort_key(self): - return self.key and self.key[1] or self.insert_order + return self.key and self.key[1] or (self.insert_order, ) def check_modified(self): if self.modified: @@ -958,7 +958,7 @@ class InstanceState(object): """a set of keys which have no uncommitted changes""" return set( - key for key in self.manager.keys() + key for key in self.manager.iterkeys() if (key not in self.committed_state or (key in self.manager.mutable_attributes and not self.manager[key].impl.check_mutable_modified(self)))) @@ -972,7 +972,7 @@ class InstanceState(object): """ return set( - key for key in self.manager.keys() + key for key in self.manager.iterkeys() if key not in self.committed_state and key not in self.dict) def expire_attributes(self, attribute_names): diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 2105a4fe6a..3c1c16b7df 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -105,15 +105,14 @@ import weakref import sqlalchemy.exceptions as sa_exc from sqlalchemy.sql import expression -from sqlalchemy import schema -import sqlalchemy.util as sautil +from sqlalchemy import schema, util __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] -__instrumentation_mutex = sautil.threading.Lock() +__instrumentation_mutex = util.threading.Lock() def column_mapped_collection(mapping_spec): @@ -131,7 +130,7 @@ def column_mapped_collection(mapping_spec): from sqlalchemy.orm.util import _state_mapper from sqlalchemy.orm.attributes import instance_state - cols = [expression._no_literals(q) for q in sautil.to_list(mapping_spec)] + cols = [expression._no_literals(q) for q in util.to_list(mapping_spec)] if len(cols) == 1: def keyfunc(value): state = instance_state(value) @@ -511,8 +510,8 @@ class CollectionAdapter(object): if converter is not None: return converter(obj) - setting_type = sautil.duck_type_collection(obj) - receiving_type = sautil.duck_type_collection(self._data()) + setting_type = util.duck_type_collection(obj) + receiving_type = util.duck_type_collection(self._data()) if obj is None or setting_type != receiving_type: given = obj is None and 'None' or obj.__class__.__name__ @@ -637,7 +636,7 @@ def bulk_replace(values, existing_adapter, new_adapter): if not isinstance(values, list): values = list(values) - idset = sautil.IdentitySet + idset = util.IdentitySet constants = idset(existing_adapter or ()).intersection(values or ()) additions = idset(values or ()).difference(constants) removals = idset(existing_adapter or ()).difference(constants) @@ -739,7 +738,7 @@ def _instrument_class(cls): "Can not instrument a built-in type. Use a " "subclass, even a trivial one.") - collection_type = sautil.duck_type_collection(cls) + collection_type = util.duck_type_collection(cls) if collection_type in __interfaces: roles = __interfaces[collection_type].copy() decorators = roles.pop('_decorators', {}) @@ -753,7 +752,7 @@ def _instrument_class(cls): for name in dir(cls): method = getattr(cls, name, None) - if not callable(method): + if not util.callable(method): continue # note role declarations @@ -825,7 +824,7 @@ def _instrument_membership_mutator(method, before, argument, after): """Route method args and/or return value through the collection adapter.""" # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' if before: - fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0])) + fn_args = list(util.flatten_iterator(inspect.getargspec(method)[0])) if type(argument) is int: pos_arg = argument named_arg = len(fn_args) > argument and fn_args[argument] or None @@ -1040,7 +1039,7 @@ def _dict_decorators(): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') - Unspecified = sautil.symbol('Unspecified') + Unspecified = util.symbol('Unspecified') def __setitem__(fn): def __setitem__(self, key, value, _sa_initiator=None): @@ -1138,7 +1137,7 @@ def _set_binops_check_strict(self, obj): def _set_binops_check_loose(self, obj): """Allow anything set-like to participate in set binops.""" return (isinstance(obj, _set_binop_bases + (self.__class__,)) or - sautil.duck_type_collection(obj) == set) + util.duck_type_collection(obj) == set) def _set_decorators(): @@ -1148,7 +1147,7 @@ def _set_decorators(): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') - Unspecified = sautil.symbol('Unspecified') + Unspecified = util.symbol('Unspecified') def add(fn): def add(self, value, _sa_initiator=None): @@ -1405,7 +1404,7 @@ class MappedCollection(dict): have assigned for that value. """ - for incoming_key, value in sautil.dictlike_iteritems(dictlike): + for incoming_key, value in util.dictlike_iteritems(dictlike): new_key = self.keyfunc(value) if incoming_key != new_key: raise TypeError( diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index cc3517f75c..ca6dec6895 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -409,8 +409,8 @@ class Mapper(object): self._pks_by_table = {} self._cols_by_table = {} - all_cols = set(chain(*[col.proxy_set for col in self._columntoproperty])) - pk_cols = set(c for c in all_cols if c.primary_key) + all_cols = util.column_set(chain(*[col.proxy_set for col in self._columntoproperty])) + pk_cols = util.column_set(c for c in all_cols if c.primary_key) # identify primary key columns which are also mapped by this mapper. tables = set(self.tables + [self.mapped_table]) @@ -418,8 +418,8 @@ class Mapper(object): for t in tables: if t.primary_key and pk_cols.issuperset(t.primary_key): # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols) - self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols) + self._pks_by_table[t] = util.ordered_column_set(t.primary_key).intersection(pk_cols) + self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(all_cols) # determine cols that aren't expressed within our tables; mark these # as "read only" properties which are refreshed upon INSERT/UPDATE @@ -470,7 +470,7 @@ class Mapper(object): # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes - self._columntoproperty = {} + self._columntoproperty = util.column_dict() # load custom properties if self._init_properties: @@ -891,7 +891,7 @@ class Mapper(object): """ params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] - return sql.and_(*[k==v for (k, v) in params]), dict(params) + return sql.and_(*[k==v for (k, v) in params]), util.column_dict(params) @util.memoized_property def _equivalent_columns(self): @@ -915,17 +915,17 @@ class Mapper(object): """ - result = {} + result = util.column_dict() def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: - result[binary.left] = set((binary.right,)) + result[binary.left] = util.column_set((binary.right,)) if binary.right in result: result[binary.right].add(binary.left) else: - result[binary.right] = set((binary.left,)) + result[binary.right] = util.column_set((binary.left,)) for mapper in self.base_mapper.polymorphic_iterator(): if mapper.inherit_condition: visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary}) @@ -1232,7 +1232,7 @@ class Mapper(object): for t in mapper.tables: table_to_mapper[t] = mapper - for table in sqlutil.sort_tables(table_to_mapper.keys()): + for table in sqlutil.sort_tables(table_to_mapper.iterkeys()): insert = [] update = [] @@ -1282,7 +1282,7 @@ class Mapper(object): if col is mapper.version_id_col: params[col._label] = mapper._get_state_attr_by_column(state, col) params[col.key] = params[col._label] + 1 - for prop in mapper._columntoproperty.values(): + for prop in mapper._columntoproperty.itervalues(): history = attributes.get_history(state, prop.key, passive=True) if history.added: hasdata = True @@ -1432,7 +1432,7 @@ class Mapper(object): for t in mapper.tables: table_to_mapper[t] = mapper - for table in reversed(sqlutil.sort_tables(table_to_mapper.keys())): + for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())): delete = {} for state, mapper, connection in tups: if table not in mapper._pks_by_table: @@ -1666,7 +1666,7 @@ class Mapper(object): """Produce a collection of attribute level row processor callables.""" new_populators, existing_populators = [], [] - for prop in self._props.values(): + for prop in self._props.itervalues(): newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter) if newpop: new_populators.append((prop.key, newpop)) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ad42117e1b..084a539d18 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -112,7 +112,7 @@ class CompositeProperty(ColumnProperty): util.warn_deprecated("The 'comparator' argument to CompositeProperty is deprecated. Use comparator_factory.") kwargs['comparator_factory'] = kwargs['comparator'] super(CompositeProperty, self).__init__(*columns, **kwargs) - self._col_position_map = dict((c, i) for i, c in enumerate(columns)) + self._col_position_map = util.column_dict((c, i) for i, c in enumerate(columns)) self.composite_class = class_ self.strategy_class = strategies.CompositeColumnLoader @@ -159,7 +159,9 @@ class CompositeProperty(ColumnProperty): return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns]) else: return expression.ClauseList(*self.prop.columns) - + + __hash__ = None + def __eq__(self, other): if other is None: values = [None] * len(self.prop.columns) @@ -363,6 +365,8 @@ class RelationProperty(StrategizedProperty): raise NotImplementedError("in_() not yet supported for relations. For a " "simple many-to-one, use in_() against the set of foreign key values.") + __hash__ = None + def __eq__(self, other): if other is None: if self.prop.direction in [ONETOMANY, MANYTOMANY]: @@ -583,7 +587,7 @@ class RelationProperty(StrategizedProperty): self.mapper = mapper.class_mapper(self.argument, compile=False) elif isinstance(self.argument, mapper.Mapper): self.mapper = self.argument - elif callable(self.argument): + elif util.callable(self.argument): # accept a callable to suit various deferred-configurational schemes self.mapper = mapper.class_mapper(self.argument(), compile=False) else: @@ -592,7 +596,7 @@ class RelationProperty(StrategizedProperty): # accept callables for other attributes which may require deferred initialization for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'): - if callable(getattr(self, attr)): + if util.callable(getattr(self, attr)): setattr(self, attr, getattr(self, attr)()) # in the case that InstrumentedAttributes were used to construct @@ -607,8 +611,8 @@ class RelationProperty(StrategizedProperty): if self.order_by: self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)] - self._foreign_keys = set(expression._literal_as_column(x) for x in util.to_set(self._foreign_keys)) - self.remote_side = set(expression._literal_as_column(x) for x in util.to_set(self.remote_side)) + self._foreign_keys = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self._foreign_keys)) + self.remote_side = util.column_set(expression._literal_as_column(x) for x in util.to_column_set(self.remote_side)) if not self.parent.concrete: for inheriting in self.parent.iterate_to_root(): @@ -727,7 +731,7 @@ class RelationProperty(StrategizedProperty): else: self.secondary_synchronize_pairs = None - self._foreign_keys = set(r for l, r in self.synchronize_pairs) + self._foreign_keys = util.column_set(r for l, r in self.synchronize_pairs) if self.secondary_synchronize_pairs: self._foreign_keys.update(r for l, r in self.secondary_synchronize_pairs) @@ -814,7 +818,7 @@ class RelationProperty(StrategizedProperty): "Specify remote_side argument to indicate which column lazy " "join condition should bind." % (r, self.mapper)) - self.local_side, self.remote_side = [util.OrderedSet(x) for x in zip(*list(self.local_remote_pairs))] + self.local_side, self.remote_side = [util.ordered_column_set(x) for x in zip(*list(self.local_remote_pairs))] def _post_init(self): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index da33eac41e..5a0c3faff9 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -753,7 +753,7 @@ class Query(object): """ aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) if kwargs: - raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys())) + raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint) @util.accepts_a_list_as_starargs(list_deprecation='pending') @@ -766,7 +766,7 @@ class Query(object): """ aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) if kwargs: - raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys())) + raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys())) return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) @_generative(__no_statement_condition, __no_limit_offset) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 58aa71c6aa..a159e4bfaa 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -451,9 +451,9 @@ class LazyLoader(AbstractRelationLoader): return (new_execute, None) def _create_lazy_clause(cls, prop, reverse_direction=False): - binds = {} - lookup = {} - equated_columns = {} + binds = util.column_dict() + lookup = util.column_dict() + equated_columns = util.column_dict() if reverse_direction and not prop.secondaryjoin: for l, r in prop.local_remote_pairs: diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 778bf09499..4efab88ae6 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -269,7 +269,7 @@ class UOWTransaction(object): def elements(self): """Iterate UOWTaskElements.""" - for task in self.tasks.values(): + for task in self.tasks.itervalues(): for elem in task.elements: yield elem @@ -288,7 +288,7 @@ class UOWTransaction(object): def _sort_dependencies(self): nodes = topological.sort_with_cycles(self.dependencies, - [t.mapper for t in self.tasks.values() if t.base_task is t] + [t.mapper for t in self.tasks.itervalues() if t.base_task is t] ) ret = [] @@ -565,7 +565,7 @@ class UOWTask(object): # as part of the topological sort itself, which would # eliminate the need for this step (but may make the original # topological sort more expensive) - head = topological.sort_as_tree(tuples, object_to_original_task.keys()) + head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys()) if head is not None: original_to_tasks = {} stack = [(head, t)] @@ -585,7 +585,7 @@ class UOWTask(object): task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) if state in dependencies: - task.cyclical_dependencies.update(dependencies[state].values()) + task.cyclical_dependencies.update(dependencies[state].itervalues()) stack += [(n, task) for n in children] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 541adf4e41..411c827c67 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,8 +4,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import new - import sqlalchemy.exceptions as sa_exc from sqlalchemy import sql, util from sqlalchemy.sql import expression, util as sql_util, operators @@ -329,7 +327,7 @@ class AliasedClass(object): if hasattr(attr, 'func_code'): is_method = getattr(self.__target, key, None) if is_method and is_method.im_self is not None: - return new.instancemethod(attr.im_func, self, self) + return util.types.MethodType(attr.im_func, self, self) else: return None elif hasattr(attr, '__get__'): @@ -570,7 +568,8 @@ def _is_mapped_class(cls): from sqlalchemy.orm import mapperlib as mapper if isinstance(cls, (AliasedClass, mapper.Mapper)): return True - + if isinstance(cls, expression.ClauseElement): + return False manager = attributes.manager_of_class(cls) return manager and _INSTRUMENTOR in manager.info diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 1d99215dc3..6aa8b0395d 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -814,14 +814,14 @@ class AssertionPool(Pool): return "AssertionPool" def create_connection(self): - raise "Invalid" + raise AssertionError("Invalid") def do_return_conn(self, conn): assert conn is self._conn and self.connection is None self.connection = conn def do_return_invalid(self, conn): - raise "Invalid" + raise AssertionError("Invalid") def do_get(self): assert self.connection is not None diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index dc523b36c6..5fa84063fd 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -574,7 +574,7 @@ class Column(SchemaItem, expression.ColumnClause): coltype = args[0] # adjust for partials - if callable(coltype): + if util.callable(coltype): coltype = args[0]() if (isinstance(coltype, types.AbstractType) or @@ -963,7 +963,7 @@ class ColumnDefault(DefaultGenerator): if isinstance(arg, FetchedValue): raise exc.ArgumentError( "ColumnDefault may not be a server-side default type.") - if callable(arg): + if util.callable(arg): arg = self._maybe_wrap_callable(arg) self.arg = arg @@ -1320,6 +1320,8 @@ class PrimaryKeyConstraint(Constraint): def copy(self, **kw): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) + __hash__ = Constraint.__hash__ + def __eq__(self, other): return self.columns == other @@ -1663,7 +1665,7 @@ class MetaData(SchemaItem): if only is None: load = [name for name in available if name not in current] - elif callable(only): + elif util.callable(only): load = [name for name in available if name not in current and only(name, self)] else: @@ -1940,7 +1942,7 @@ class DDL(object): "Expected a string or unicode SQL statement, got '%r'" % statement) if (on is not None and - (not isinstance(on, basestring) and not callable(on))): + (not isinstance(on, basestring) and not util.callable(on))): raise exc.ArgumentError( "Expected the name of a database dialect or a callable for " "'on' criteria, got type '%s'." % type(on).__name__) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 747978e762..921d932d2f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -162,7 +162,7 @@ class DefaultCompiler(engine.Compiled): # a dictionary of _BindParamClause instances to "compiled" names that are # actually present in the generated SQL - self.bind_names = {} + self.bind_names = util.column_dict() # stack which keeps track of nested SELECT statements self.stack = [] @@ -205,6 +205,7 @@ class DefaultCompiler(engine.Compiled): """return a dictionary of bind parameter keys and values""" if params: + params = util.column_dict(params) pd = {} for bindparam, name in self.bind_names.iteritems(): for paramname in (bindparam, bindparam.key, bindparam.shortname, name): @@ -212,7 +213,7 @@ class DefaultCompiler(engine.Compiled): pd[name] = params[paramname] break else: - if callable(bindparam.value): + if util.callable(bindparam.value): pd[name] = bindparam.value() else: pd[name] = bindparam.value @@ -220,7 +221,7 @@ class DefaultCompiler(engine.Compiled): else: pd = {} for bindparam in self.bind_names: - if callable(bindparam.value): + if util.callable(bindparam.value): pd[self.bind_names[bindparam]] = bindparam.value() else: pd[self.bind_names[bindparam]] = bindparam.value @@ -317,7 +318,7 @@ class DefaultCompiler(engine.Compiled): sep = clauselist.operator if sep is None: sep = " " - elif sep == operators.comma_op: + elif sep is operators.comma_op: sep = ', ' else: sep = " " + self.operator_string(clauselist.operator) + " " @@ -336,7 +337,7 @@ class DefaultCompiler(engine.Compiled): name = self.function_string(func) - if callable(name): + if util.callable(name): return name(*[self.process(x) for x in func.clauses]) else: return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} @@ -377,7 +378,7 @@ class DefaultCompiler(engine.Compiled): def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) - if callable(op): + if util.callable(op): return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0f7f62e74c..a4ff72b1af 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -997,7 +997,7 @@ class ClauseElement(Visitable): of transformative operations. """ - s = set() + s = util.column_set() f = self while f is not None: s.add(f) @@ -1258,6 +1258,8 @@ class ColumnOperators(Operators): def __le__(self, other): return self.operate(operators.le, other) + __hash__ = Operators.__hash__ + def __eq__(self, other): return self.operate(operators.eq, other) @@ -1580,12 +1582,12 @@ class ColumnElement(ClauseElement, _CompareMixin): @util.memoized_property def base_columns(self): - return set(c for c in self.proxy_set + return util.column_set(c for c in self.proxy_set if not hasattr(c, 'proxies')) @util.memoized_property def proxy_set(self): - s = set([self]) + s = util.column_set([self]) if hasattr(self, 'proxies'): for c in self.proxies: s.update(c.proxy_set) @@ -1694,6 +1696,8 @@ class ColumnCollection(util.OrderedProperties): for c in iter: self.add(c) + __hash__ = None + def __eq__(self, other): l = [] for c in other: @@ -1711,9 +1715,9 @@ class ColumnCollection(util.OrderedProperties): # have to use a Set here, because it will compare the identity # of the column, not just using "==" for comparison which will always return a # "True" value (i.e. a BinaryClause...) - return col in set(self) + return col in util.column_set(self) -class ColumnSet(util.OrderedSet): +class ColumnSet(util.ordered_column_set): def contains_column(self, col): return col in self @@ -1733,7 +1737,7 @@ class ColumnSet(util.OrderedSet): return and_(*l) def __hash__(self): - return hash(tuple(self._list)) + return hash(tuple(x for x in self)) class Selectable(ClauseElement): """mark a class as being selectable""" @@ -1985,7 +1989,7 @@ class _BindParamClause(ColumnElement): d = self.__dict__.copy() v = self.value - if callable(v): + if util.callable(v): v = v() d['value'] = v return d @@ -2369,7 +2373,7 @@ class _BinaryExpression(ColumnElement): def self_group(self, against=None): # use small/large defaults for comparison so that unknown # operators are always parenthesized - if self.operator != against and operators.is_precedent(self.operator, against): + if self.operator is not against and operators.is_precedent(self.operator, against): return _Grouping(self) else: return self diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9b9b9ec094..d0ca0b01fd 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -7,6 +7,7 @@ from itertools import chain def sort_tables(tables): """sort a collection of Table objects in order of their foreign-key dependency.""" + tables = list(tables) tuples = [] def visit_foreign_key(fkey): if fkey.use_alter: @@ -60,7 +61,7 @@ def find_tables(clause, check_columns=False, include_aliases=False, include_join def find_columns(clause): """locate Column objects within the given expression.""" - cols = set() + cols = util.column_set() def visit_column(col): cols.add(col) visitors.traverse(clause, {}, {'column':visit_column}) @@ -182,7 +183,7 @@ class Annotated(object): # to this object's __dict__. clone.__dict__.update(self.__dict__) return Annotated(clone, self._annotations) - + def __hash__(self): return hash(self.__element) @@ -279,9 +280,9 @@ def reduce_columns(columns, *clauses, **kw): """ ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) - columns = util.OrderedSet(columns) + columns = util.column_set(columns) - omit = set() + omit = util.column_set() for col in columns: for fk in col.foreign_keys: for c in columns: @@ -301,7 +302,7 @@ def reduce_columns(columns, *clauses, **kw): if clauses: def visit_binary(binary): if binary.operator == operators.eq: - cols = set(chain(*[c.proxy_set for c in columns.difference(omit)])) + cols = util.column_set(chain(*[c.proxy_set for c in columns.difference(omit)])) if binary.left in cols and binary.right in cols: for c in columns: if c.shares_lineage(binary.right): @@ -444,7 +445,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): self.selectable = selectable self.include = include self.exclude = exclude - self.equivalents = equivalents or {} + self.equivalents = util.column_dict(equivalents or {}) def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) @@ -484,7 +485,7 @@ class ColumnAdapter(ClauseAdapter): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: self.chain(chain_to) - self.columns = util.PopulateDict(self._locate_col) + self.columns = util.populate_column_dict(self._locate_col) def wrap(self, adapter): ac = self.__class__.__new__(self.__class__) @@ -492,7 +493,7 @@ class ColumnAdapter(ClauseAdapter): ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col) ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) - ac.columns = util.PopulateDict(ac._locate_col) + ac.columns = util.populate_column_dict(ac._locate_col) return ac adapt_clause = ClauseAdapter.traverse diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 17b9c59d56..a5bd497aed 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -207,7 +207,7 @@ def traverse_depthfirst(obj, opts, visitors): def cloned_traverse(obj, opts, visitors): """clone the given expression structure, allowing modifications by visitors.""" - cloned = {} + cloned = util.column_dict() def clone(element): if element not in cloned: @@ -234,8 +234,8 @@ def cloned_traverse(obj, opts, visitors): def replacement_traverse(obj, opts, replace): """clone the given expression structure, allowing element replacement by a given replacement function.""" - cloned = {} - stop_on = set(opts.get('stop_on', [])) + cloned = util.column_dict() + stop_on = util.column_set(opts.get('stop_on', [])) def clone(element): newelem = replace(element) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 2604f4e8fe..7eda27f4fb 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -382,7 +382,7 @@ class Concatenable(object): def adapt_operator(self, op): """Converts an add operator to concat.""" from sqlalchemy.sql import operators - if op == operators.add: + if op is operators.add: return operators.concat_op else: return op diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8b68fb1086..1356fa324b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect, itertools, new, operator, sys, warnings, weakref +import inspect, itertools, operator, sys, warnings, weakref import __builtin__ types = __import__('types') @@ -18,8 +18,13 @@ except ImportError: import dummy_threading as threading from dummy_threading import local as ThreadLocal -if sys.version_info < (2, 6): +py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) + +if py3k: + set_types = set +elif sys.version_info < (2, 6): import sets + set_types = set, sets.Set else: # 2.6 deprecates sets.Set, but we still need to be able to detect them # in user code and as return values from DB-APIs @@ -32,15 +37,24 @@ else: import sets warnings.filters.remove(ignore) -set_types = set, sets.Set + set_types = set, sets.Set EMPTY_SET = frozenset() -try: - import cPickle as pickle -except ImportError: +if py3k: import pickle +else: + try: + import cPickle as pickle + except ImportError: + import pickle +if py3k: + def buffer(x): + return x # no-op until we figure out what MySQLdb is going to use +else: + buffer = __builtin__.buffer + if sys.version_info >= (2, 5): class PopulateDict(dict): """A dict which populates missing values via a creation function. @@ -70,6 +84,17 @@ else: self[key] = value = self.creator(key) return value +if py3k: + def callable(fn): + return hasattr(fn, '__call__') +else: + callable = __builtin__.callable + +if py3k: + from functools import reduce +else: + reduce = __builtin__.reduce + try: from collections import defaultdict except ImportError: @@ -125,6 +150,14 @@ def to_set(x): else: return x +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + try: from functools import update_wrapper @@ -823,10 +856,11 @@ class IdentitySet(object): This strategy has edge cases for builtin types- it's possible to have two 'foo' strings in one of these sets, for example. Use sparingly. + """ _working_set = set - + def __init__(self, iterable=None): self._members = dict() if iterable: @@ -918,7 +952,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).union(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).union(_iter_id(iterable))) return result def __or__(self, other): @@ -939,7 +973,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).difference(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).difference(_iter_id(iterable))) return result def __sub__(self, other): @@ -960,7 +994,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).intersection(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable))) return result def __and__(self, other): @@ -981,9 +1015,12 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable))) return result - + + def _member_id_tuples(self): + return ((id(v), v) for v in self._members.itervalues()) + def __xor__(self, other): if not isinstance(other, IdentitySet): return NotImplemented @@ -1016,11 +1053,6 @@ class IdentitySet(object): return '%s(%r)' % (type(self).__name__, self._members.values()) -def _iter_id(iterable): - """Generator: ((id(o), o) for o in iterable).""" - for item in iterable: - yield id(item), item - class OrderedIdentitySet(IdentitySet): class _working_set(OrderedSet): # a testing pragma: exempt the OIDS working set from the test suite's @@ -1028,7 +1060,7 @@ class OrderedIdentitySet(IdentitySet): # but it's safe here: IDS operates on (id, instance) tuples in the # working set. __sa_hash_exempt__ = True - + def __init__(self, iterable=None): IdentitySet.__init__(self) self._members = OrderedDict() @@ -1036,6 +1068,19 @@ class OrderedIdentitySet(IdentitySet): for o in iterable: self.add(o) +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + + for item in iterable: + yield id(item), item + +# define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet +populate_column_dict = PopulateDict + def unique_list(seq, compare_with=set): seen = compare_with() return [x for x in seq if x not in seen and not seen.add(x)] @@ -1296,7 +1341,7 @@ def function_named(fn, name): try: fn.__name__ = name except TypeError: - fn = new.function(fn.func_code, fn.func_globals, name, + fn = types.FunctionType(fn.func_code, fn.func_globals, name, fn.func_defaults, fn.func_closure) return fn diff --git a/test/base/utils.py b/test/base/utils.py index c3d026f045..0add6dd66c 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -116,6 +116,7 @@ class HashOverride(object): class EqOverride(object): def __init__(self, value=None): self.value = value + __hash__ = object.__hash__ def __eq__(self, other): if isinstance(other, EqOverride): return self.value == other.value @@ -260,6 +261,32 @@ class IdentitySetTest(unittest.TestCase): self.assertRaises(TypeError, lambda: s1 - os2) self.assertRaises(TypeError, lambda: s1 - [3, 4, 5]) +class OrderedIdentitySetTest(unittest.TestCase): + + def assert_eq(self, identityset, expected_iterable): + expected = [id(o) for o in expected_iterable] + found = [id(o) for o in identityset] + eq_(found, expected) + + def test_add(self): + elem = object + s = util.OrderedIdentitySet() + s.add(elem()) + s.add(elem()) + + def test_intersection(self): + elem = object + eq_ = self.assert_eq + + a, b, c, d, e, f, g = elem(), elem(), elem(), elem(), elem(), elem(), elem() + + s1 = util.OrderedIdentitySet([a, b, c]) + s2 = util.OrderedIdentitySet([d, e, f]) + s3 = util.OrderedIdentitySet([a, d, f, g]) + eq_(s1.intersection(s2), []) + eq_(s1.intersection(s3), [a]) + eq_(s1.union(s2).intersection(s3), [a, d, f]) + class DictlikeIteritemsTest(unittest.TestCase): baseline = set([('a', 1), ('b', 2), ('c', 3)]) diff --git a/test/ext/serializer.py b/test/ext/serializer.py index 0a900a9e9f..21765ff9c6 100644 --- a/test/ext/serializer.py +++ b/test/ext/serializer.py @@ -3,7 +3,7 @@ import testenv; testenv.configure_for_tests() from sqlalchemy.ext import serializer from sqlalchemy import exc from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func +from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func, util from testlib.sa.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased from testlib.testing import eq_ from orm._base import ComparableEntity, MappedTest @@ -88,7 +88,9 @@ class SerializeTest(testing.ORMTest): re_expr.execute().fetchall(), [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')] ) - + + # fails due to pure Python pickle bug: http://bugs.python.org/issue998998 + @testing.fails_if(lambda: util.py3k) def test_query(self): q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses)) eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index 1fc05380ab..c6444a333c 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -329,7 +329,7 @@ class EagerTest3(_base.MappedTest): arb_result = arb_data.execute().fetchall() # order the result list descending based on 'max' - arb_result.sort(lambda a, b: cmp(b['max'], a['max'])) + arb_result.sort(key = lambda a: a['max'], reverse=True) # extract just the "data_id" from it arb_result = [row['data_id'] for row in arb_result] diff --git a/test/orm/collection.py b/test/orm/collection.py index c37a20b681..f121450219 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -148,7 +148,7 @@ class CollectionsTest(_base.ORMTest): control[0] = e assert_eq() - if reduce(and_, [hasattr(direct, a) for a in + if util.reduce(and_, [hasattr(direct, a) for a in ('__delitem__', 'insert', '__len__')], True): values = [creator(), creator(), creator(), creator()] direct[slice(0,1)] = values @@ -365,6 +365,7 @@ class CollectionsTest(_base.ORMTest): assert False def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): @@ -392,6 +393,7 @@ class CollectionsTest(_base.ORMTest): assert False def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): @@ -735,6 +737,7 @@ class CollectionsTest(_base.ORMTest): self.data.update(other) def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other @@ -760,6 +763,7 @@ class CollectionsTest(_base.ORMTest): self.data.update(other) def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other @@ -1037,6 +1041,7 @@ class CollectionsTest(_base.ORMTest): @collection.iterator def itervalues(self): return self.data.itervalues() + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): @@ -1076,6 +1081,7 @@ class CollectionsTest(_base.ORMTest): @collection.iterator def itervalues(self): return self.data.itervalues() + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other def __repr__(self): @@ -1153,6 +1159,7 @@ class CollectionsTest(_base.ORMTest): @collection.iterator def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other @@ -1183,6 +1190,7 @@ class CollectionsTest(_base.ORMTest): @collection.iterator def __iter__(self): return iter(self.data) + __hash__ = object.__hash__ def __eq__(self, other): return self.data == other diff --git a/test/orm/utils.py b/test/orm/utils.py index 0a449fbf70..52c055110c 100644 --- a/test/orm/utils.py +++ b/test/orm/utils.py @@ -143,7 +143,7 @@ class AliasedClassTest(TestBase): def test_hybrid_descriptors(self): from sqlalchemy import Column # override testlib's override - import new + import types class MethodDescriptor(object): def __init__(self, func): @@ -153,7 +153,7 @@ class AliasedClassTest(TestBase): args = (self.func, owner, owner.__class__) else: args = (self.func, instance, owner) - return new.instancemethod(*args) + return types.MethodType(*args) class PropertyDescriptor(object): def __init__(self, fget, fset, fdel): diff --git a/test/pickleable.py b/test/pickleable.py index 3ffc1e59be..ffb22f3a24 100644 --- a/test/pickleable.py +++ b/test/pickleable.py @@ -7,6 +7,7 @@ class Foo(object): self.data = 'im data' self.stuff = 'im stuff' self.moredata = moredata + __hash__ = object.__hash__ def __eq__(self, other): return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata @@ -15,6 +16,7 @@ class Bar(object): def __init__(self, x, y): self.x = x self.y = y + __hash__ = object.__hash__ def __eq__(self, other): return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y def __str__(self): diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 3f9464283d..eb8bc861f5 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -8,6 +8,7 @@ from testlib import * from sqlalchemy.sql import util as sql_util, visitors from sqlalchemy import exc from sqlalchemy.sql import table, column +from sqlalchemy import util metadata = MetaData() table1 = Table('table1', metadata, @@ -288,13 +289,13 @@ class PrimaryKeyTest(TestBase, AssertsExecutionResults): ) self.assertEquals( - set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key), - set([employee.c.id]) + util.column_set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key), + util.column_set([employee.c.id]) ) self.assertEquals( - set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key), - set([employee.c.id]) + util.column_set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key), + util.column_set([employee.c.id]) ) @@ -313,8 +314,8 @@ class ReduceTest(TestBase, AssertsExecutionResults): self.assertEquals( - set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])), - set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data]) + util.column_set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])), + util.column_set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data]) ) def test_reduce_selectable(self): @@ -332,8 +333,8 @@ class ReduceTest(TestBase, AssertsExecutionResults): s = select([engineers, managers]).where(engineers.c.engineer_name==managers.c.manager_name) - self.assertEquals(set(sql_util.reduce_columns(list(s.c), s)), - set([s.c.engineer_id, s.c.engineer_name, s.c.manager_id]) + self.assertEquals(util.column_set(sql_util.reduce_columns(list(s.c), s)), + util.column_set([s.c.engineer_id, s.c.engineer_name, s.c.manager_id]) ) def test_reduce_aliased_join(self): @@ -358,8 +359,8 @@ class ReduceTest(TestBase, AssertsExecutionResults): pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') self.assertEquals( - set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])), - set([pjoin.c.people_person_id]) + util.column_set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])), + util.column_set([pjoin.c.people_person_id]) ) def test_reduce_aliased_union(self): @@ -382,8 +383,8 @@ class ReduceTest(TestBase, AssertsExecutionResults): }, None, 'item_join') self.assertEquals( - set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])), - set([item_join.c.id, item_join.c.dummy, item_join.c.child_name]) + util.column_set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])), + util.column_set([item_join.c.id, item_join.c.dummy, item_join.c.child_name]) ) def test_reduce_aliased_union_2(self): @@ -407,8 +408,8 @@ class ReduceTest(TestBase, AssertsExecutionResults): }, None, 'page_join') self.assertEquals( - set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), - set([pjoin.c.id]) + util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), + util.column_set([pjoin.c.id]) ) diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 02eebf6312..7f8240d4c3 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -292,6 +292,9 @@ class UnicodeTest(TestBase, AssertsExecutionResults): assert unicode_table.c.unicode_varchar.type.length == 250 rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') + if testing.against('sqlite'): + rawdata = "something" + unicode_table.insert().execute(unicode_varchar=unicodedata, unicode_text=unicodedata, plain_varchar=rawdata) @@ -301,7 +304,8 @@ class UnicodeTest(TestBase, AssertsExecutionResults): if isinstance(x['plain_varchar'], unicode): # SQLLite and MSSQL return non-unicode data as unicode self.assert_(testing.against('sqlite', 'mssql')) - self.assert_(x['plain_varchar'] == unicodedata) + if not testing.against('sqlite'): + self.assert_(x['plain_varchar'] == unicodedata) print "it's %s!" % testing.db.name else: self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) @@ -311,6 +315,8 @@ class UnicodeTest(TestBase, AssertsExecutionResults): rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') + if testing.against('sqlite'): + rawdata = "something" unicode_table.insert().execute(unicode_varchar=unicodedata, unicode_text=unicodedata, plain_varchar=rawdata) @@ -358,6 +364,8 @@ class UnicodeTest(TestBase, AssertsExecutionResults): testing.db.engine.dialect.assert_unicode = False rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') + if testing.against('sqlite'): + rawdata = "something" unicode_table.insert().execute(unicode_varchar=unicodedata, unicode_text=unicodedata, plain_varchar=rawdata) @@ -368,7 +376,8 @@ class UnicodeTest(TestBase, AssertsExecutionResults): print 3, repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) + if not testing.against('sqlite'): + self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) finally: testing.db.engine.dialect.convert_unicode = prev_unicode testing.db.engine.dialect.convert_unicode = prev_assert diff --git a/test/testlib/compat.py b/test/testlib/compat.py index 0b157e64a0..374f144f64 100644 --- a/test/testlib/compat.py +++ b/test/testlib/compat.py @@ -1,4 +1,4 @@ -import new +import types __all__ = '_function_named', @@ -7,7 +7,7 @@ def _function_named(fn, newname): try: fn.__name__ = newname except: - fn = new.function(fn.func_code, fn.func_globals, newname, + fn = types.FunctionType(fn.func_code, fn.func_globals, newname, fn.func_defaults, fn.func_closure) return fn diff --git a/test/testlib/engines.py b/test/testlib/engines.py index 000b188ce2..2a16d3f494 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -2,6 +2,7 @@ import sys, types, weakref from collections import deque from testlib import config from testlib.compat import _function_named +from sqlalchemy.util import callable class ConnectionKiller(object): def __init__(self): diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py index 854cd7c5ae..8629cca7f2 100644 --- a/test/testlib/fixtures.py +++ b/test/testlib/fixtures.py @@ -25,6 +25,8 @@ class Base(object): def __ne__(self, other): return not self.__eq__(other) + __hash__ = object.__hash__ + def __eq__(self, other): """'passively' compare this object to another. diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 650fa86a75..b26bcd93c3 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -11,6 +11,8 @@ import unittest import warnings from cStringIO import StringIO +from sqlalchemy.util import callable + import testlib.config as config from testlib.compat import _function_named diff --git a/test/zblog/user.py b/test/zblog/user.py index 973413d922..4fe3b8f8d9 100644 --- a/test/zblog/user.py +++ b/test/zblog/user.py @@ -1,7 +1,7 @@ """user.py - handles user login and validation""" import random, string -from sha import sha +from hashlib import sha1 as sha administrator = 'admin' user = 'user'