]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- adapt initial patch from [ticket:1917] to current tip
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Nov 2010 20:53:14 +0000 (15:53 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Nov 2010 20:53:14 +0000 (15:53 -0500)
- raise TypeError for immutability

lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/util.py
test/engine/test_metadata.py
test/engine/test_reflection.py
test/sql/test_compiler.py

index e9da4f5337fc3c82461c26f0237852b82ef32bc9..8abb26fb6f4220418be9caada664d990dfcfe74c 100644 (file)
@@ -162,7 +162,7 @@ class Mapper(object):
         else:
             self.with_polymorphic = None
 
-        if isinstance(self.local_table, expression._SelectBaseMixin):
+        if isinstance(self.local_table, expression._SelectBase):
             raise sa_exc.InvalidRequestError(
                 "When mapping against a select() construct, map against "
                 "an alias() of the construct instead."
@@ -172,7 +172,7 @@ class Mapper(object):
 
         if self.with_polymorphic and \
                     isinstance(self.with_polymorphic[1],
-                                expression._SelectBaseMixin):
+                                expression._SelectBase):
             self.with_polymorphic = (self.with_polymorphic[0],
                                 self.with_polymorphic[1].alias())
 
index 2bccb8f73fadc45e6b8c4b16fd3007f482a5b5a8..2f482537d7ae6eddb397cfd10d4cc3054ba308ec 100644 (file)
@@ -162,7 +162,7 @@ class Query(object):
 
         fa = []
         for from_obj in obj:
-            if isinstance(from_obj, expression._SelectBaseMixin):
+            if isinstance(from_obj, expression._SelectBase):
                 from_obj = from_obj.alias()
             fa.append(from_obj)
 
@@ -1597,7 +1597,7 @@ class Query(object):
 
         if not isinstance(statement, 
                             (expression._TextClause,
-                            expression._SelectBaseMixin)):
+                            expression._SelectBase)):
             raise sa_exc.ArgumentError(
                             "from_statement accepts text(), select(), "
                             "and union() objects only.")
@@ -2468,7 +2468,7 @@ class Query(object):
 
             for hint in self._with_hints:
                 statement = statement.with_hint(*hint)
-                        
+                    
             if self._execution_options:
                 statement = statement.execution_options(
                                             **self._execution_options)
@@ -2803,7 +2803,7 @@ class QueryContext(object):
     def __init__(self, query):
 
         if query._statement is not None:
-            if isinstance(query._statement, expression._SelectBaseMixin) and \
+            if isinstance(query._statement, expression._SelectBase) and \
                                 not query._statement.use_labels:
                 self.statement = query._statement.apply_labels()
             else:
index a332cec36154300658e4cd08d39cf2cbb935db09..e7a5d6e464a39d439f1ce7e371b8b2f301b584ea 100644 (file)
@@ -227,7 +227,7 @@ class Table(SchemaItem, expression.TableClause):
         self.constraints = set()
         self._columns = expression.ColumnCollection()
         self._set_primary_key(PrimaryKeyConstraint())
-        self._foreign_keys = util.OrderedSet()
+        self.foreign_keys = util.OrderedSet()
         self._extra_dependencies = set()
         self.ddl_listeners = util.defaultdict(list)
         self.kwargs = {}
@@ -283,7 +283,7 @@ class Table(SchemaItem, expression.TableClause):
         if include_columns:
             for c in self.c:
                 if c.name not in include_columns:
-                    self.c.remove(c)
+                    self._columns.remove(c)
 
         for key in ('quote', 'quote_schema'):
             if key in kwargs:
@@ -307,10 +307,13 @@ class Table(SchemaItem, expression.TableClause):
                 "Invalid argument(s) for Table: %r" % kwargs.keys())
         self.kwargs.update(kwargs)
 
+    def _init_collections(self):
+        pass
+        
     def _set_primary_key(self, pk):
-        if getattr(self, '_primary_key', None) in self.constraints:
-            self.constraints.remove(self._primary_key)
-        self._primary_key = pk
+        if self.primary_key in self.constraints:
+            self.constraints.remove(self.primary_key)
+        self.primary_key = pk
         self.constraints.add(pk)
 
         for c in pk.columns:
@@ -330,10 +333,6 @@ class Table(SchemaItem, expression.TableClause):
     def key(self):
         return _get_table_key(self.name, self.schema)
 
-    @property
-    def primary_key(self):
-        return self._primary_key
-
     def __repr__(self):
         return "Table(%s)" % ', '.join(
             [repr(self.name)] + [repr(self.metadata)] +
@@ -937,7 +936,7 @@ class Column(SchemaItem, expression.ColumnClause):
             nullable = self.nullable, 
             quote=self.quote, _proxies=[self], *fk)
         c.table = selectable
-        selectable.columns.add(c)
+        selectable._columns.add(c)
         if self.primary_key:
             selectable.primary_key.add(c)
         for fn in c._table_events:
index c3dc339a50dd9adaac92e502228e2b0c9767da94..0f93643dc7701f9e884667ee8d945fe4623943da 100644 (file)
@@ -1658,7 +1658,7 @@ class _CompareMixin(ColumnOperators):
         if isinstance(seq_or_selectable, _ScalarSelect):
             return self.__compare(op, seq_or_selectable,
                                   negate=negate_op)
-        elif isinstance(seq_or_selectable, _SelectBaseMixin):
+        elif isinstance(seq_or_selectable, _SelectBase):
 
             # TODO: if we ever want to support (x, y, z) IN (select x,
             # y, z from table), we would need a multi-column version of
@@ -1830,7 +1830,7 @@ class _CompareMixin(ColumnOperators):
             return other.__clause_element__()
         elif not isinstance(other, ClauseElement):
             return self._bind_param(operator, other)
-        elif isinstance(other, (_SelectBaseMixin, Alias)):
+        elif isinstance(other, (_SelectBase, Alias)):
             return other.as_scalar()
         else:
             return other
@@ -1905,7 +1905,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
         co = ColumnClause(name, selectable, type_=getattr(self,
                           'type', None))
         co.proxies = [self]
-        selectable.columns[key] = co
+        selectable._columns[key] = co
         return co
 
     def compare(self, other, use_proxies=False, equivalents=None, **kw):
@@ -2044,6 +2044,16 @@ class ColumnCollection(util.OrderedProperties):
         # always return a "True" value (i.e. a BinaryClause...)
 
         return col in util.column_set(self)
+    
+    def as_immutable(self):
+        return ImmutableColumnCollection(self._data)
+        
+class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
+    def __init__(self, data):
+        util.ImmutableProperties.__init__(self, data)
+    
+    extend = remove = util.ImmutableProperties._immutable
+
 
 class ColumnSet(util.ordered_column_set):
     def contains_column(self, col):
@@ -2239,44 +2249,50 @@ class FromClause(Selectable):
     def _reset_exported(self):
         """delete memoized collections when a FromClause is cloned."""
 
-        for attr in '_columns', '_primary_key', '_foreign_keys', \
-            'locate_all_froms':
-            self.__dict__.pop(attr, None)
+        for name in 'primary_key', '_columns', 'columns', \
+                'foreign_keys', 'locate_all_froms':
+            self.__dict__.pop(name, None)
 
     @util.memoized_property
-    def _columns(self):
+    def columns(self):
         """Return the collection of Column objects contained by this
         FromClause."""
-
-        self._export_columns()
-        return self._columns
-
+        
+        if '_columns' not in self.__dict__:
+            self._init_collections()
+            self._populate_column_collection()
+        return self._columns.as_immutable()
+    
     @util.memoized_property
-    def _primary_key(self):
+    def primary_key(self):
         """Return the collection of Column objects which comprise the
         primary key of this FromClause."""
-
-        self._export_columns()
-        return self._primary_key
-
+        
+        self._init_collections()
+        self._populate_column_collection()
+        return self.primary_key
+    
     @util.memoized_property
-    def _foreign_keys(self):
+    def foreign_keys(self):
         """Return the collection of ForeignKey objects which this
         FromClause references."""
+        
+        self._init_collections()
+        self._populate_column_collection()
+        return self.foreign_keys
 
-        self._export_columns()
-        return self._foreign_keys
-    columns = property(attrgetter('_columns'), doc=_columns.__doc__)
-    primary_key = property(attrgetter('_primary_key'),
-                           doc=_primary_key.__doc__)
-    foreign_keys = property(attrgetter('_foreign_keys'),
-                            doc=_foreign_keys.__doc__)
-
-    # synonyms for 'columns'
-
-    c = _select_iterable = property(attrgetter('columns'),
-                                    doc=_columns.__doc__)
-
+    c = property(attrgetter('columns'))
+    _select_iterable = property(attrgetter('columns'))
+    
+    def _init_collections(self):
+        assert '_columns' not in self.__dict__
+        assert 'primary_key' not in self.__dict__
+        assert 'foreign_keys' not in self.__dict__
+            
+        self._columns = ColumnCollection()
+        self.primary_key = ColumnSet()
+        self.foreign_keys = set()
+         
     def _export_columns(self):
         """Initialize column collections."""
 
@@ -3009,7 +3025,7 @@ class _Exists(_UnaryExpression):
     _from_objects = []
 
     def __init__(self, *args, **kwargs):
-        if args and isinstance(args[0], (_SelectBaseMixin, _ScalarSelect)):
+        if args and isinstance(args[0], (_SelectBase, _ScalarSelect)):
             s = args[0]
         else:
             if not args:
@@ -3088,10 +3104,10 @@ class Join(FromClause):
         columns = [c for c in self.left.columns] + \
                         [c for c in self.right.columns]
 
-        self._primary_key.extend(sqlutil.reduce_columns(
+        self.primary_key.extend(sqlutil.reduce_columns(
                 (c for c in columns if c.primary_key), self.onclause))
         self._columns.update((col._label, col) for col in columns)
-        self._foreign_keys.update(itertools.chain(
+        self.foreign_keys.update(itertools.chain(
                         *[col.foreign_keys for col in columns]))
 
     def _copy_internals(self, clone=_clone):
@@ -3281,11 +3297,25 @@ class _FromGrouping(FromClause):
 
     def __init__(self, element):
         self.element = element
-
+    
+    def _init_collections(self):
+        pass
+        
     @property
     def columns(self):
         return self.element.columns
 
+    @property
+    def primary_key(self):
+        return self.element.primary_key
+
+    @property
+    def foreign_keys(self):
+        # this could be
+        # self.element.foreign_keys
+        # see SelectableTest.test_join_condition
+        return set()
+
     @property
     def _hide_froms(self):
         return self.element._hide_froms
@@ -3476,7 +3506,7 @@ class ColumnClause(_Immutable, ColumnElement):
                 )
         c.proxies = [self]
         if attach:
-            selectable.columns[c.name] = c
+            selectable._columns[c.name] = c
         return c
 
 class TableClause(_Immutable, FromClause):
@@ -3496,11 +3526,14 @@ class TableClause(_Immutable, FromClause):
         super(TableClause, self).__init__()
         self.name = self.fullname = name
         self._columns = ColumnCollection()
-        self._primary_key = ColumnSet()
-        self._foreign_keys = set()
+        self.primary_key = ColumnSet()
+        self.foreign_keys = set()
         for c in columns:
             self.append_column(c)
-
+    
+    def _init_collections(self):
+        pass
+        
     def _export_columns(self):
         raise NotImplementedError()
 
@@ -3556,7 +3589,7 @@ class TableClause(_Immutable, FromClause):
     def _from_objects(self):
         return [self]
 
-class _SelectBaseMixin(Executable):
+class _SelectBase(Executable, FromClause):
     """Base class for :class:`Select` and ``CompoundSelects``."""
 
     def __init__(self,
@@ -3583,7 +3616,7 @@ class _SelectBaseMixin(Executable):
 
         self._order_by_clause = ClauseList(*util.to_list(order_by) or [])
         self._group_by_clause = ClauseList(*util.to_list(group_by) or [])
-
+    
     def as_scalar(self):
         """return a 'scalar' representation of this selectable, which can be
         used as a column expression.
@@ -3729,7 +3762,7 @@ class _ScalarSelect(_Grouping):
     def _make_proxy(self, selectable, name):
         return list(self.inner_columns)[0]._make_proxy(selectable, name)
 
-class CompoundSelect(_SelectBaseMixin, FromClause):
+class CompoundSelect(_SelectBase):
     """Forms the basis of ``UNION``, ``UNION ALL``, and other 
         SELECT-based set operations."""
 
@@ -3764,7 +3797,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
 
             self.selects.append(s.self_group(self))
 
-        _SelectBaseMixin.__init__(self, **kwargs)
+        _SelectBase.__init__(self, **kwargs)
     
     def _scalar_type(self):
         return self.selects[0]._scalar_type()
@@ -3830,7 +3863,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         self._bind = bind
     bind = property(bind, _set_bind)
 
-class Select(_SelectBaseMixin, FromClause):
+class Select(_SelectBase):
     """Represents a ``SELECT`` statement.
 
     Select statements support appendable clauses, as well as the
@@ -3859,7 +3892,7 @@ class Select(_SelectBaseMixin, FromClause):
         argument descriptions.
 
         Additional generative and mutator methods are available on the
-        :class:`_SelectBaseMixin` superclass.
+        :class:`_SelectBase` superclass.
 
         """
         self._should_correlate = correlate
@@ -3907,7 +3940,7 @@ class Select(_SelectBaseMixin, FromClause):
         if prefixes:
             self._prefixes = tuple([_literal_as_text(p) for p in prefixes])
 
-        _SelectBaseMixin.__init__(self, **kwargs)
+        _SelectBase.__init__(self, **kwargs)
 
     def _get_display_froms(self, existing_froms=None):
         """Return the full list of 'from' clauses to be displayed.
index 6f70b1778d2cfeb7dfbdee363ac346a759c94f01..67830f7cf5ffdb31879efcd52d5a631df6378326 100644 (file)
@@ -83,7 +83,7 @@ def desc_op(a):
 def asc_op(a):
     return a.asc()
 
-_commutative = set([eq, ne, add, mul])
+_commutative = set([eq, ne, add, mul, and_])
 def is_commutative(op):
     return op in _commutative
 
index 8665cd0d4ac699d8ec075ae9bcfa7427aec4c2c9..cfeb38f54c0f379f1d51d0a486197a9c31e2a169 100644 (file)
@@ -146,36 +146,6 @@ except ImportError:
             return 'defaultdict(%s, %s)' % (self.default_factory,
                                             dict.__repr__(self))
 
-class frozendict(dict):
-    @property
-    def _blocked_attribute(obj):
-        raise AttributeError, "A frozendict cannot be modified."
-
-    __delitem__ = __setitem__ = clear = _blocked_attribute
-    pop = popitem = setdefault = update = _blocked_attribute
-
-    def __new__(cls, *args):
-        new = dict.__new__(cls)
-        dict.__init__(new, *args)
-        return new
-
-    def __init__(self, *args):
-        pass
-
-    def __reduce__(self):
-        return frozendict, (dict(self), )
-
-    def union(self, d):
-        if not self:
-            return frozendict(d)
-        else:
-            d2 = self.copy()
-            d2.update(d)
-            return frozendict(d2)
-            
-    def __repr__(self):
-        return "frozendict(%s)" % dict.__repr__(self)
-
 
 # find or create a dict implementation that supports __missing__
 class _probe(dict):
@@ -759,20 +729,44 @@ class NamedTuple(tuple):
     def keys(self):
         return [l for l in self._labels if l is not None]
 
+class ImmutableContainer(object):
+    def _immutable(self, *arg, **kw):
+        raise TypeError("%s object is immutable" % self.__class__.__name__)
 
-class OrderedProperties(object):
-    """An object that maintains the order in which attributes are set upon it.
+    __delitem__ = __setitem__ = __setattr__ = _immutable
 
-    Also provides an iterator and a very basic getitem/setitem
-    interface to those attributes.
+class frozendict(ImmutableContainer, dict):
+    
+    clear = pop = popitem = setdefault = \
+        update = ImmutableContainer._immutable
 
-    (Not really a dict, since it iterates over values, not keys.  Not really
-    a list, either, since each value must have a key associated; hence there is
-    no append or extend.)
-    """
+    def __new__(cls, *args):
+        new = dict.__new__(cls)
+        dict.__init__(new, *args)
+        return new
 
-    def __init__(self):
-        self.__dict__['_data'] = OrderedDict()
+    def __init__(self, *args):
+        pass
+
+    def __reduce__(self):
+        return frozendict, (dict(self), )
+
+    def union(self, d):
+        if not self:
+            return frozendict(d)
+        else:
+            d2 = self.copy()
+            d2.update(d)
+            return frozendict(d2)
+            
+    def __repr__(self):
+        return "frozendict(%s)" % dict.__repr__(self)
+
+class Properties(object):
+    """Provide a __getattr__/__setattr__ interface over a dict."""
+
+    def __init__(self, data):
+        self.__dict__['_data'] = data
 
     def __len__(self):
         return len(self._data)
@@ -809,7 +803,12 @@ class OrderedProperties(object):
 
     def __contains__(self, key):
         return key in self._data
-
+    
+    def as_immutable(self):
+        """Return an immutable proxy for this :class:`.Properties`."""
+        
+        return ImmutableProperties(self._data)
+        
     def update(self, value):
         self._data.update(value)
 
@@ -828,6 +827,17 @@ class OrderedProperties(object):
     def clear(self):
         self._data.clear()
 
+class OrderedProperties(Properties):
+    """Provide a __getattr__/__setattr__ interface with an OrderedDict
+    as backing store."""
+    def __init__(self):
+        Properties.__init__(self, OrderedDict())
+    
+
+class ImmutableProperties(ImmutableContainer, Properties):
+    """Provide immutable dict/object attribute to an underlying dictionary."""
+        
+    
 class OrderedDict(dict):
     """A dict that returns keys/values/items in the order they were added."""
 
index b2250c808f38d5dd9c813828077d73a24ef776ba..b3a9cef2e5393e8e361eb69edb52deceff20a772 100644 (file)
@@ -353,7 +353,6 @@ class MetaDataTest(TestBase, ComparesTables):
             [d, b, a, c, e]
         )
         
-        
     def test_tometadata_strip_schema(self):
         meta = MetaData()
 
@@ -387,7 +386,7 @@ class MetaDataTest(TestBase, ComparesTables):
                           MetaData(testing.db), autoload=True)
 
 
-class TableOptionsTest(TestBase, AssertsCompiledSQL):
+class TableTest(TestBase, AssertsCompiledSQL):
     def test_prefixes(self):
         table1 = Table("temporary_table_1", MetaData(),
                       Column("col1", Integer),
@@ -418,3 +417,27 @@ class TableOptionsTest(TestBase, AssertsCompiledSQL):
             t.info['bar'] = 'zip'
             assert t.info['bar'] == 'zip'
 
+    def test_c_immutable(self):
+        m = MetaData()
+        t1 = Table('t', m, Column('x', Integer), Column('y', Integer))
+        assert_raises(
+            TypeError,
+            t1.c.extend, [Column('z', Integer)]
+        )
+
+        def assign():
+            t1.c['z'] = Column('z', Integer)
+        assert_raises(
+            TypeError,
+            assign
+        )
+
+        def assign():
+            t1.c.z = Column('z', Integer)
+        assert_raises(
+            TypeError,
+            assign
+        )
+        
+        
+    
\ No newline at end of file
index d0d6e31e13856249097a7121c4e839e8646b4a7d..91e73d4f20a65f5209bbc1a9b25e216df948511c 100644 (file)
@@ -611,6 +611,7 @@ class ReflectionTest(TestBase, ComparesTables):
             self.assert_tables_equal(multi, table)
             self.assert_tables_equal(multi2, table2)
             j = sa.join(table, table2)
+
             self.assert_(sa.and_(table.c.multi_id == table2.c.foo,
                          table.c.multi_rev == table2.c.bar,
                          table.c.multi_hoho
index 338a5491eed51b200bc51398fe3c9c89ab220190..2d6f0e1045a189f6334bfd613061d10060921686 100644 (file)
@@ -1553,7 +1553,7 @@ class SelectTest(TestBase, AssertsCompiledSQL):
             "SELECT foo, bar FROM bat UNION SELECT foo, bar "
             "FROM bat UNION SELECT foo, bar FROM bat UNION SELECT foo, bar FROM bat"
         )
-
+        
         self.assert_compile(
             union(s, union(s, union(s, s))),
             "SELECT foo, bar FROM bat UNION (SELECT foo, bar FROM bat "