]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merge -r5658:5665 from trunk
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Jan 2009 17:37:38 +0000 (17:37 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Jan 2009 17:37:38 +0000 (17:37 +0000)
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/schema.py
test/ext/declarative.py
test/orm/unitofwork.py

index a5cb6e9d2d11fd30ef339516a75773840804d5fb..3b4880403ae13cca96d8361835cfe75effc6002e 100644 (file)
@@ -478,7 +478,9 @@ def _as_declarative(cls, classname, dict_):
                                           *(tuple(cols) + tuple(args)), **table_kw)
     else:
         table = cls.__table__
-
+        if cols:
+            raise exceptions.ArgumentError("Can't add additional columns when specifying __table__")
+            
     mapper_args = getattr(cls, '__mapper_args__', {})
     if 'inherits' not in mapper_args:
         inherits = cls.__mro__[1]
@@ -530,7 +532,7 @@ def _as_declarative(cls, classname, dict_):
             mapper_args['exclude_properties'] = exclude_properties = \
                 set([c.key for c in inherited_table.c if c not in inherited_mapper._columntoproperty])
             exclude_properties.difference_update([c.key for c in cols])
-        
+    
     cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
 
 class DeclarativeMeta(type):
index 04fc9d0ef1381691e3005c69bbc2a21e472e33ee..6bcc89b3c23cf96d1319beed23d6d877b7017f6f 100644 (file)
@@ -1305,13 +1305,15 @@ class Mapper(object):
                                 if col in pks:
                                     if history.deleted:
                                         params[col._label] = prop.get_col_value(col, history.deleted[0])
+                                        hasdata = True
                                     else:
                                         # row switch logic can reach us here
                                         # remove the pk from the update params so the update doesn't
                                         # attempt to include the pk in the update statement
                                         del params[col.key]
                                         params[col._label] = prop.get_col_value(col, history.added[0])
-                                hasdata = True
+                                else:
+                                    hasdata = True
                             elif col in pks:
                                 params[col._label] = mapper._get_state_attr_by_column(state, col)
                     if hasdata:
index 0211b9707ac69ba6f16b38133a1a8d76323be4aa..a4561d443df6300764f107c9cd3b5d31149252c6 100644 (file)
@@ -32,12 +32,24 @@ class ColumnProperty(StrategizedProperty):
     """Describes an object attribute that corresponds to a table column."""
 
     def __init__(self, *columns, **kwargs):
-        """The list of `columns` describes a single object
-        property. If there are multiple tables joined together for the
-        mapper, this list represents the equivalent column as it
-        appears across each table.
-        """
+        """Construct a ColumnProperty.
+
+        :param \*columns: The list of `columns` describes a single
+          object property. If there are multiple tables joined
+          together for the mapper, this list represents the equivalent
+          column as it appears across each table.
+
+        :param group:
+
+        :param deferred:
+
+        :param comparator_factory:
 
+        :param descriptor:
+
+        :param extension:
+
+        """
         self.columns = [expression._labeled(c) for c in columns]
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
@@ -45,6 +57,11 @@ class ColumnProperty(StrategizedProperty):
         self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator)
         self.descriptor = kwargs.pop('descriptor', None)
         self.extension = kwargs.pop('extension', None)
+        if kwargs:
+            raise TypeError(
+                "%s received unexpected keyword argument(s): %s" % (
+                    self.__class__.__name__, ', '.join(sorted(kwargs.keys()))))
+
         util.set_creation_order(self)
         if self.no_instrument:
             self.strategy_class = strategies.UninstrumentedColumnLoader
@@ -1136,4 +1153,4 @@ mapper.ColumnProperty = ColumnProperty
 mapper.SynonymProperty = SynonymProperty
 mapper.ComparableProperty = ComparableProperty
 mapper.RelationProperty = RelationProperty
-mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
\ No newline at end of file
+mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
index c9dc152b9845432303afb6da2d0a1cc98df91619..6d7a8f2693dc8e8ffee00c98fdc8aeea19fcd2a4 100644 (file)
@@ -609,7 +609,9 @@ class Column(SchemaItem, expression.ColumnClause):
                 "Unknown arguments passed to Column: " + repr(kwargs.keys()))
 
     def __str__(self):
-        if self.table is not None:
+        if self.name is None:
+            return "(no name)"
+        elif self.table is not None:
             if self.table.named_with_column:
                 return (self.table.description + "." + self.description)
             else:
@@ -617,9 +619,9 @@ class Column(SchemaItem, expression.ColumnClause):
         else:
             return self.description
 
+    @property
     def bind(self):
         return self.table.bind
-    bind = property(bind)
 
     def references(self, column):
         """Return True if this Column references the given column via foreign key."""
index 3176832f309fbd65be37676f496751eaa75900e5..c9477b5d85c1e53f51fa6dbb7b1c1dd6b90f2984 100644 (file)
@@ -63,6 +63,26 @@ class DeclarativeTest(DeclarativeTestBase):
             class User(Base):
                 id = Column('id', Integer, primary_key=True)
         self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go)
+
+    def test_cant_add_columns(self):
+        t = Table('t', Base.metadata, Column('id', Integer, primary_key=True))
+        def go():
+            class User(Base):
+                __table__ = t
+                foo = Column(Integer, primary_key=True)
+        self.assertRaisesMessage(sa.exc.ArgumentError, "add additional columns", go)
+    
+    def test_undefer_column_name(self):
+        # TODO: not sure if there was an explicit
+        # test for this elsewhere
+        foo = Column(Integer)
+        eq_(str(foo), '(no name)')
+        eq_(foo.key, None)
+        eq_(foo.name, None)
+        decl._undefer_column_name('foo', foo)
+        eq_(str(foo), 'foo')
+        eq_(foo.key, 'foo')
+        eq_(foo.name, 'foo')
         
     def test_recompile_on_othermapper(self):
         """declarative version of the same test in mappers.py"""
index a4363b5e5ff2a06a686a7676033e7da59c1db4bc..553713da539fe233bcbe1d1f2d9d793463d9e331 100644 (file)
@@ -2216,6 +2216,49 @@ class RowSwitchTest(_base.MappedTest):
         assert list(sess.execute(t5.select(), mapper=T5)) == [(2, 'some other t5')]
         assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)]
 
+class InheritingRowSwitchTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        Table('parent', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('pdata', String(30))
+        )
+        Table('child', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('pid', Integer, ForeignKey('parent.id')),
+            Column('cdata', String(30))
+        )
+
+    def setup_classes(self):
+        class P(_base.ComparableEntity):
+            pass
+
+        class C(P):
+            pass
+    
+    @testing.resolve_artifact_names
+    def test_row_switch_no_child_table(self):
+        mapper(P, parent)
+        mapper(C, child, inherits=P)
+        
+        sess = create_session()
+        c1 = C(id=1, pdata='c1', cdata='c1')
+        sess.add(c1)
+        sess.flush()
+        
+        # establish a row switch between c1 and c2.
+        # c2 has no value for the "child" table
+        c2 = C(id=1, pdata='c2')
+        sess.add(c2)
+        sess.delete(c1)
+
+        self.assert_sql_execution(testing.db, sess.flush,
+            CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id",
+                {'pdata':'c2', 'parent_id':1}
+            )
+        )
+        
+        
+
 class TransactionTest(_base.MappedTest):
     __requires__ = ('deferrable_constraints',)