From 041a329e69f6aa60bdd2f3fb87b5172481806c4a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 16 Nov 2010 15:53:14 -0500 Subject: [PATCH] - adapt initial patch from [ticket:1917] to current tip - raise TypeError for immutability --- lib/sqlalchemy/orm/mapper.py | 4 +- lib/sqlalchemy/orm/query.py | 8 +- lib/sqlalchemy/schema.py | 19 +++-- lib/sqlalchemy/sql/expression.py | 123 ++++++++++++++++++++----------- lib/sqlalchemy/sql/operators.py | 2 +- lib/sqlalchemy/util.py | 92 ++++++++++++----------- test/engine/test_metadata.py | 27 ++++++- test/engine/test_reflection.py | 1 + test/sql/test_compiler.py | 2 +- 9 files changed, 172 insertions(+), 106 deletions(-) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e9da4f5337..8abb26fb6f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -162,7 +162,7 @@ class Mapper(object): else: self.with_polymorphic = None - if isinstance(self.local_table, expression._SelectBaseMixin): + if isinstance(self.local_table, expression._SelectBase): raise sa_exc.InvalidRequestError( "When mapping against a select() construct, map against " "an alias() of the construct instead." @@ -172,7 +172,7 @@ class Mapper(object): if self.with_polymorphic and \ isinstance(self.with_polymorphic[1], - expression._SelectBaseMixin): + expression._SelectBase): self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias()) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 2bccb8f73f..2f482537d7 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -162,7 +162,7 @@ class Query(object): fa = [] for from_obj in obj: - if isinstance(from_obj, expression._SelectBaseMixin): + if isinstance(from_obj, expression._SelectBase): from_obj = from_obj.alias() fa.append(from_obj) @@ -1597,7 +1597,7 @@ class Query(object): if not isinstance(statement, (expression._TextClause, - expression._SelectBaseMixin)): + expression._SelectBase)): raise sa_exc.ArgumentError( "from_statement accepts text(), select(), " "and union() objects only.") @@ -2468,7 +2468,7 @@ class Query(object): for hint in self._with_hints: statement = statement.with_hint(*hint) - + if self._execution_options: statement = statement.execution_options( **self._execution_options) @@ -2803,7 +2803,7 @@ class QueryContext(object): def __init__(self, query): if query._statement is not None: - if isinstance(query._statement, expression._SelectBaseMixin) and \ + if isinstance(query._statement, expression._SelectBase) and \ not query._statement.use_labels: self.statement = query._statement.apply_labels() else: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index a332cec361..e7a5d6e464 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -227,7 +227,7 @@ class Table(SchemaItem, expression.TableClause): self.constraints = set() self._columns = expression.ColumnCollection() self._set_primary_key(PrimaryKeyConstraint()) - self._foreign_keys = util.OrderedSet() + self.foreign_keys = util.OrderedSet() self._extra_dependencies = set() self.ddl_listeners = util.defaultdict(list) self.kwargs = {} @@ -283,7 +283,7 @@ class Table(SchemaItem, expression.TableClause): if include_columns: for c in self.c: if c.name not in include_columns: - self.c.remove(c) + self._columns.remove(c) for key in ('quote', 'quote_schema'): if key in kwargs: @@ -307,10 +307,13 @@ class Table(SchemaItem, expression.TableClause): "Invalid argument(s) for Table: %r" % kwargs.keys()) self.kwargs.update(kwargs) + def _init_collections(self): + pass + def _set_primary_key(self, pk): - if getattr(self, '_primary_key', None) in self.constraints: - self.constraints.remove(self._primary_key) - self._primary_key = pk + if self.primary_key in self.constraints: + self.constraints.remove(self.primary_key) + self.primary_key = pk self.constraints.add(pk) for c in pk.columns: @@ -330,10 +333,6 @@ class Table(SchemaItem, expression.TableClause): def key(self): return _get_table_key(self.name, self.schema) - @property - def primary_key(self): - return self._primary_key - def __repr__(self): return "Table(%s)" % ', '.join( [repr(self.name)] + [repr(self.metadata)] + @@ -937,7 +936,7 @@ class Column(SchemaItem, expression.ColumnClause): nullable = self.nullable, quote=self.quote, _proxies=[self], *fk) c.table = selectable - selectable.columns.add(c) + selectable._columns.add(c) if self.primary_key: selectable.primary_key.add(c) for fn in c._table_events: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index c3dc339a50..0f93643dc7 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1658,7 +1658,7 @@ class _CompareMixin(ColumnOperators): if isinstance(seq_or_selectable, _ScalarSelect): return self.__compare(op, seq_or_selectable, negate=negate_op) - elif isinstance(seq_or_selectable, _SelectBaseMixin): + elif isinstance(seq_or_selectable, _SelectBase): # TODO: if we ever want to support (x, y, z) IN (select x, # y, z from table), we would need a multi-column version of @@ -1830,7 +1830,7 @@ class _CompareMixin(ColumnOperators): return other.__clause_element__() elif not isinstance(other, ClauseElement): return self._bind_param(operator, other) - elif isinstance(other, (_SelectBaseMixin, Alias)): + elif isinstance(other, (_SelectBase, Alias)): return other.as_scalar() else: return other @@ -1905,7 +1905,7 @@ class ColumnElement(ClauseElement, _CompareMixin): co = ColumnClause(name, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] - selectable.columns[key] = co + selectable._columns[key] = co return co def compare(self, other, use_proxies=False, equivalents=None, **kw): @@ -2044,6 +2044,16 @@ class ColumnCollection(util.OrderedProperties): # always return a "True" value (i.e. a BinaryClause...) return col in util.column_set(self) + + def as_immutable(self): + return ImmutableColumnCollection(self._data) + +class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): + def __init__(self, data): + util.ImmutableProperties.__init__(self, data) + + extend = remove = util.ImmutableProperties._immutable + class ColumnSet(util.ordered_column_set): def contains_column(self, col): @@ -2239,44 +2249,50 @@ class FromClause(Selectable): def _reset_exported(self): """delete memoized collections when a FromClause is cloned.""" - for attr in '_columns', '_primary_key', '_foreign_keys', \ - 'locate_all_froms': - self.__dict__.pop(attr, None) + for name in 'primary_key', '_columns', 'columns', \ + 'foreign_keys', 'locate_all_froms': + self.__dict__.pop(name, None) @util.memoized_property - def _columns(self): + def columns(self): """Return the collection of Column objects contained by this FromClause.""" - - self._export_columns() - return self._columns - + + if '_columns' not in self.__dict__: + self._init_collections() + self._populate_column_collection() + return self._columns.as_immutable() + @util.memoized_property - def _primary_key(self): + def primary_key(self): """Return the collection of Column objects which comprise the primary key of this FromClause.""" - - self._export_columns() - return self._primary_key - + + self._init_collections() + self._populate_column_collection() + return self.primary_key + @util.memoized_property - def _foreign_keys(self): + def foreign_keys(self): """Return the collection of ForeignKey objects which this FromClause references.""" + + self._init_collections() + self._populate_column_collection() + return self.foreign_keys - self._export_columns() - return self._foreign_keys - columns = property(attrgetter('_columns'), doc=_columns.__doc__) - primary_key = property(attrgetter('_primary_key'), - doc=_primary_key.__doc__) - foreign_keys = property(attrgetter('_foreign_keys'), - doc=_foreign_keys.__doc__) - - # synonyms for 'columns' - - c = _select_iterable = property(attrgetter('columns'), - doc=_columns.__doc__) - + c = property(attrgetter('columns')) + _select_iterable = property(attrgetter('columns')) + + def _init_collections(self): + assert '_columns' not in self.__dict__ + assert 'primary_key' not in self.__dict__ + assert 'foreign_keys' not in self.__dict__ + + self._columns = ColumnCollection() + self.primary_key = ColumnSet() + self.foreign_keys = set() + def _export_columns(self): """Initialize column collections.""" @@ -3009,7 +3025,7 @@ class _Exists(_UnaryExpression): _from_objects = [] def __init__(self, *args, **kwargs): - if args and isinstance(args[0], (_SelectBaseMixin, _ScalarSelect)): + if args and isinstance(args[0], (_SelectBase, _ScalarSelect)): s = args[0] else: if not args: @@ -3088,10 +3104,10 @@ class Join(FromClause): columns = [c for c in self.left.columns] + \ [c for c in self.right.columns] - self._primary_key.extend(sqlutil.reduce_columns( + self.primary_key.extend(sqlutil.reduce_columns( (c for c in columns if c.primary_key), self.onclause)) self._columns.update((col._label, col) for col in columns) - self._foreign_keys.update(itertools.chain( + self.foreign_keys.update(itertools.chain( *[col.foreign_keys for col in columns])) def _copy_internals(self, clone=_clone): @@ -3281,11 +3297,25 @@ class _FromGrouping(FromClause): def __init__(self, element): self.element = element - + + def _init_collections(self): + pass + @property def columns(self): return self.element.columns + @property + def primary_key(self): + return self.element.primary_key + + @property + def foreign_keys(self): + # this could be + # self.element.foreign_keys + # see SelectableTest.test_join_condition + return set() + @property def _hide_froms(self): return self.element._hide_froms @@ -3476,7 +3506,7 @@ class ColumnClause(_Immutable, ColumnElement): ) c.proxies = [self] if attach: - selectable.columns[c.name] = c + selectable._columns[c.name] = c return c class TableClause(_Immutable, FromClause): @@ -3496,11 +3526,14 @@ class TableClause(_Immutable, FromClause): super(TableClause, self).__init__() self.name = self.fullname = name self._columns = ColumnCollection() - self._primary_key = ColumnSet() - self._foreign_keys = set() + self.primary_key = ColumnSet() + self.foreign_keys = set() for c in columns: self.append_column(c) - + + def _init_collections(self): + pass + def _export_columns(self): raise NotImplementedError() @@ -3556,7 +3589,7 @@ class TableClause(_Immutable, FromClause): def _from_objects(self): return [self] -class _SelectBaseMixin(Executable): +class _SelectBase(Executable, FromClause): """Base class for :class:`Select` and ``CompoundSelects``.""" def __init__(self, @@ -3583,7 +3616,7 @@ class _SelectBaseMixin(Executable): self._order_by_clause = ClauseList(*util.to_list(order_by) or []) self._group_by_clause = ClauseList(*util.to_list(group_by) or []) - + def as_scalar(self): """return a 'scalar' representation of this selectable, which can be used as a column expression. @@ -3729,7 +3762,7 @@ class _ScalarSelect(_Grouping): def _make_proxy(self, selectable, name): return list(self.inner_columns)[0]._make_proxy(selectable, name) -class CompoundSelect(_SelectBaseMixin, FromClause): +class CompoundSelect(_SelectBase): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations.""" @@ -3764,7 +3797,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): self.selects.append(s.self_group(self)) - _SelectBaseMixin.__init__(self, **kwargs) + _SelectBase.__init__(self, **kwargs) def _scalar_type(self): return self.selects[0]._scalar_type() @@ -3830,7 +3863,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): self._bind = bind bind = property(bind, _set_bind) -class Select(_SelectBaseMixin, FromClause): +class Select(_SelectBase): """Represents a ``SELECT`` statement. Select statements support appendable clauses, as well as the @@ -3859,7 +3892,7 @@ class Select(_SelectBaseMixin, FromClause): argument descriptions. Additional generative and mutator methods are available on the - :class:`_SelectBaseMixin` superclass. + :class:`_SelectBase` superclass. """ self._should_correlate = correlate @@ -3907,7 +3940,7 @@ class Select(_SelectBaseMixin, FromClause): if prefixes: self._prefixes = tuple([_literal_as_text(p) for p in prefixes]) - _SelectBaseMixin.__init__(self, **kwargs) + _SelectBase.__init__(self, **kwargs) def _get_display_froms(self, existing_froms=None): """Return the full list of 'from' clauses to be displayed. diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 6f70b1778d..67830f7cf5 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -83,7 +83,7 @@ def desc_op(a): def asc_op(a): return a.asc() -_commutative = set([eq, ne, add, mul]) +_commutative = set([eq, ne, add, mul, and_]) def is_commutative(op): return op in _commutative diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8665cd0d4a..cfeb38f54c 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -146,36 +146,6 @@ except ImportError: return 'defaultdict(%s, %s)' % (self.default_factory, dict.__repr__(self)) -class frozendict(dict): - @property - def _blocked_attribute(obj): - raise AttributeError, "A frozendict cannot be modified." - - __delitem__ = __setitem__ = clear = _blocked_attribute - pop = popitem = setdefault = update = _blocked_attribute - - def __new__(cls, *args): - new = dict.__new__(cls) - dict.__init__(new, *args) - return new - - def __init__(self, *args): - pass - - def __reduce__(self): - return frozendict, (dict(self), ) - - def union(self, d): - if not self: - return frozendict(d) - else: - d2 = self.copy() - d2.update(d) - return frozendict(d2) - - def __repr__(self): - return "frozendict(%s)" % dict.__repr__(self) - # find or create a dict implementation that supports __missing__ class _probe(dict): @@ -759,20 +729,44 @@ class NamedTuple(tuple): def keys(self): return [l for l in self._labels if l is not None] +class ImmutableContainer(object): + def _immutable(self, *arg, **kw): + raise TypeError("%s object is immutable" % self.__class__.__name__) -class OrderedProperties(object): - """An object that maintains the order in which attributes are set upon it. + __delitem__ = __setitem__ = __setattr__ = _immutable - Also provides an iterator and a very basic getitem/setitem - interface to those attributes. +class frozendict(ImmutableContainer, dict): + + clear = pop = popitem = setdefault = \ + update = ImmutableContainer._immutable - (Not really a dict, since it iterates over values, not keys. Not really - a list, either, since each value must have a key associated; hence there is - no append or extend.) - """ + def __new__(cls, *args): + new = dict.__new__(cls) + dict.__init__(new, *args) + return new - def __init__(self): - self.__dict__['_data'] = OrderedDict() + def __init__(self, *args): + pass + + def __reduce__(self): + return frozendict, (dict(self), ) + + def union(self, d): + if not self: + return frozendict(d) + else: + d2 = self.copy() + d2.update(d) + return frozendict(d2) + + def __repr__(self): + return "frozendict(%s)" % dict.__repr__(self) + +class Properties(object): + """Provide a __getattr__/__setattr__ interface over a dict.""" + + def __init__(self, data): + self.__dict__['_data'] = data def __len__(self): return len(self._data) @@ -809,7 +803,12 @@ class OrderedProperties(object): def __contains__(self, key): return key in self._data - + + def as_immutable(self): + """Return an immutable proxy for this :class:`.Properties`.""" + + return ImmutableProperties(self._data) + def update(self, value): self._data.update(value) @@ -828,6 +827,17 @@ class OrderedProperties(object): def clear(self): self._data.clear() +class OrderedProperties(Properties): + """Provide a __getattr__/__setattr__ interface with an OrderedDict + as backing store.""" + def __init__(self): + Properties.__init__(self, OrderedDict()) + + +class ImmutableProperties(ImmutableContainer, Properties): + """Provide immutable dict/object attribute to an underlying dictionary.""" + + class OrderedDict(dict): """A dict that returns keys/values/items in the order they were added.""" diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py index b2250c808f..b3a9cef2e5 100644 --- a/test/engine/test_metadata.py +++ b/test/engine/test_metadata.py @@ -353,7 +353,6 @@ class MetaDataTest(TestBase, ComparesTables): [d, b, a, c, e] ) - def test_tometadata_strip_schema(self): meta = MetaData() @@ -387,7 +386,7 @@ class MetaDataTest(TestBase, ComparesTables): MetaData(testing.db), autoload=True) -class TableOptionsTest(TestBase, AssertsCompiledSQL): +class TableTest(TestBase, AssertsCompiledSQL): def test_prefixes(self): table1 = Table("temporary_table_1", MetaData(), Column("col1", Integer), @@ -418,3 +417,27 @@ class TableOptionsTest(TestBase, AssertsCompiledSQL): t.info['bar'] = 'zip' assert t.info['bar'] == 'zip' + def test_c_immutable(self): + m = MetaData() + t1 = Table('t', m, Column('x', Integer), Column('y', Integer)) + assert_raises( + TypeError, + t1.c.extend, [Column('z', Integer)] + ) + + def assign(): + t1.c['z'] = Column('z', Integer) + assert_raises( + TypeError, + assign + ) + + def assign(): + t1.c.z = Column('z', Integer) + assert_raises( + TypeError, + assign + ) + + + \ No newline at end of file diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index d0d6e31e13..91e73d4f20 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -611,6 +611,7 @@ class ReflectionTest(TestBase, ComparesTables): self.assert_tables_equal(multi, table) self.assert_tables_equal(multi2, table2) j = sa.join(table, table2) + self.assert_(sa.and_(table.c.multi_id == table2.c.foo, table.c.multi_rev == table2.c.bar, table.c.multi_hoho diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 338a5491ee..2d6f0e1045 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1553,7 +1553,7 @@ class SelectTest(TestBase, AssertsCompiledSQL): "SELECT foo, bar FROM bat UNION SELECT foo, bar " "FROM bat UNION SELECT foo, bar FROM bat UNION SELECT foo, bar FROM bat" ) - + self.assert_compile( union(s, union(s, union(s, s))), "SELECT foo, bar FROM bat UNION (SELECT foo, bar FROM bat " -- 2.47.2