]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The composite() property type now supports
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Aug 2008 17:18:10 +0000 (17:18 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Aug 2008 17:18:10 +0000 (17:18 +0000)
a __set_composite_values__() method on the composite
class which is required if the class represents
state using attribute names other than the
column's keynames; default-generated values now
get populated properly upon flush.  Also,
composites with attributes set to None compare
correctly.  [ticket:1132]

CHANGES
doc/build/content/mappers.txt
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 30b0baae5d907ac1a0bd812412b5ca31eee9010a..503a9c16eecff386b148dd760e80f8524168be55 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -25,6 +25,15 @@ CHANGES
     - Fixed @on_reconsitute hook for subclasses 
       which inherit from a base class.
       [ticket:1129]
+
+    - The composite() property type now supports
+      a __set_composite_values__() method on the composite
+      class which is required if the class represents
+      state using attribute names other than the 
+      column's keynames; default-generated values now
+      get populated properly upon flush.  Also,
+      composites with attributes set to None compare
+      correctly.  [ticket:1132]
       
 - sql
     - Temporarily rolled back the "ORDER BY" enhancement
index 3149e9972d95f77a443b840b4c820530e7e8131f..f0821e6e249cf7891465d5261a140b71f8d0ddf9 100644 (file)
@@ -181,7 +181,7 @@ Sets of columns can be associated with a single datatype.  The ORM treats the gr
         Column('y2', Integer),
         )
 
-The requirements for the custom datatype class are that it have a constructor which accepts positional arguments corresponding to its column format, and also provides a method `__composite_values__()` which returns the state of the object as a list or tuple, in order of its column-based attributes.  It also should supply adequate `__eq__()` and `__ne__()` methods which test the equality of two instances:
+The requirements for the custom datatype class are that it have a constructor which accepts positional arguments corresponding to its column format, and also provides a method `__composite_values__()` which returns the state of the object as a list or tuple, in order of its column-based attributes.  It also should supply adequate `__eq__()` and `__ne__()` methods which test the equality of two instances, and may optionally provide a `__set_composite_values__` method which is used to set internal state in some cases (typically when default values have been generated during a flush):
     
     {python}
     class Point(object):
@@ -190,11 +190,16 @@ The requirements for the custom datatype class are that it have a constructor wh
             self.y = y
         def __composite_values__(self):
             return [self.x, self.y]            
+        def __set_composite_values__(self, x, y):
+            self.x = x
+            self.y = y
         def __eq__(self, other):
             return other.x == self.x and other.y == self.y
         def __ne__(self, other):
             return not self.__eq__(other)
 
+If `__set_composite_values__()` is not provided, the names of the mapped columns are taken as the names of attributes on the object, and `setattr()` is used to set data.
+
 Setting up the mapping uses the `composite()` function:
 
 
index 34867c871f7448f7b3d4878fe034ba39e914055c..425a41b3719a562b665304c450922c12563a22fc 100644 (file)
@@ -416,11 +416,29 @@ def composite(class_, *cols, **kwargs):
               self.x = x
               self.y = y
           def __composite_values__(self):
-              return (self.x, self.y)
-
+              return self.x, self.y
+          def __eq__(self, other):
+              return other is not None and self.x == other.x and self.y == other.y
+              
       # and then in the mapping:
       ... composite(Point, mytable.c.x, mytable.c.y) ...
 
+    The composite object may have its attributes populated based on the names
+    of the mapped columns.  To override the way internal state is set,
+    additionally implement ``__set_composite_values__``:
+        
+        class Point(object):
+            def __init__(self, x, y):
+                self.some_x = x
+                self.some_y = y
+            def __composite_values__(self):
+                return self.some_x, self.some_y
+            def __set_composite_values__(self, x, y):
+                self.some_x = x
+                self.some_y = y
+            def __eq__(self, other):
+                return other is not None and self.some_x == other.x and self.some_y == other.y
+
     Arguments are:
 
     class\_
index 75b9835bb1f13ed538660aa66ae2a5885c74dcc9..f46fd722d0c4c4eeab9dbdd888301b3749eab765 100644 (file)
@@ -103,6 +103,7 @@ class CompositeProperty(ColumnProperty):
 
     def __init__(self, class_, *columns, **kwargs):
         super(CompositeProperty, self).__init__(*columns, **kwargs)
+        self._col_position_map = dict((c, i) for i, c in enumerate(columns))
         self.composite_class = class_
         self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator)
         self.strategy_class = strategies.CompositeColumnLoader
@@ -123,17 +124,22 @@ class CompositeProperty(ColumnProperty):
         return self.get_col_value(column, obj)
 
     def setattr(self, state, value, column):
-        # TODO: test coverage for this method
+
         obj = state.get_impl(self.key).get(state)
         if obj is None:
             obj = self.composite_class(*[None for c in self.columns])
             state.get_impl(self.key).set(state, obj, None)
 
-        for a, b in zip(self.columns, value.__composite_values__()):
-            if a is column:
-                setattr(obj, b, value)
-
+        if hasattr(obj, '__set_composite_values__'):
+            values = list(obj.__composite_values__())
+            values[self._col_position_map[column]] = value
+            obj.__set_composite_values__(*values)
+        else:
+            setattr(obj, column.key, value)
+            
     def get_col_value(self, column, value):
+        if value is None:
+            return None
         for a, b in zip(self.columns, value.__composite_values__()):
             if a is column:
                 return b
@@ -144,16 +150,12 @@ class CompositeProperty(ColumnProperty):
 
         def __eq__(self, other):
             if other is None:
-                return sql.and_(*[a==None for a in self.prop.columns])
+                values = [None] * len(self.prop.columns)
             else:
-                return sql.and_(*[a==b for a, b in
-                                  zip(self.prop.columns,
-                                      other.__composite_values__())])
-
+                values = other.__composite_values__()
+            return sql.and_(*[a==b for a, b in zip(self.prop.columns, values)])
         def __ne__(self, other):
-            return sql.or_(*[a!=b for a, b in
-                             zip(self.prop.columns,
-                                 other.__composite_values__())])
+            return sql.not_(self.__eq__(other))
 
     def __str__(self):
         return str(self.parent.class_.__name__) + "." + self.key
index b5250638649bc13650978ab441e470c2146311c8..b9801370b4e0502a85836fd1a71689766ebd71bc 100644 (file)
@@ -1150,7 +1150,7 @@ class Query(object):
 
             _get_clause = q._adapt_clause(_get_clause, True, False)
             q._criterion = _get_clause
-
+            
             for i, primary_key in enumerate(mapper.primary_key):
                 try:
                     params[_get_params[primary_key].key] = ident[i]
index f0eb93e3fd4e876784bbebcb7ae3a56605ea56e7..e2adf701dcde68896c29c591339fb03f458813d3 100644 (file)
@@ -95,6 +95,8 @@ class CompositeColumnLoader(ColumnLoader):
         self.logger.info("%s register managed composite attribute" % self)
 
         def copy(obj):
+            if obj is None:
+                return None
             return self.parent_property.composite_class(*obj.__composite_values__())
             
         def compare(a, b):
index 3f578244775f39c85768aa054f77419aaa4192c9..02db1c8d146a48b9063a2261cf79bcc4bfa81c29 100644 (file)
@@ -1366,7 +1366,7 @@ class CompositeTypesTest(_base.MappedTest):
     def define_tables(self, metadata):
         Table('graphs', metadata,
             Column('id', Integer, primary_key=True),
-            Column('version_id', Integer, primary_key=True),
+            Column('version_id', Integer, primary_key=True, nullable=True),
             Column('name', String(30)))
 
         Table('edges', metadata,
@@ -1382,6 +1382,14 @@ class CompositeTypesTest(_base.MappedTest):
             ['graph_id', 'graph_version_id'],
             ['graphs.id', 'graphs.version_id']))
 
+        Table('foobars', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('x1', Integer, default=2),
+            Column('x2', Integer),
+            Column('x3', Integer, default=15),
+            Column('x4', Integer)
+        )
+        
     @testing.resolve_artifact_names
     def test_basic(self):
         class Point(object):
@@ -1465,7 +1473,6 @@ class CompositeTypesTest(_base.MappedTest):
                 self.id = id
                 self.version = version
             def __composite_values__(self):
-                # a tuple this time
                 return (self.id, self.version)
             def __eq__(self, other):
                 return other.id == self.id and other.version == self.version
@@ -1476,7 +1483,7 @@ class CompositeTypesTest(_base.MappedTest):
             def __init__(self, version):
                 self.version = version
 
-        mapper(Graph, graphs, properties={
+        mapper(Graph, graphs, allow_null_pks=True, properties={
             'version':sa.orm.composite(Version, graphs.c.id,
                                        graphs.c.version_id)})
 
@@ -1492,6 +1499,134 @@ class CompositeTypesTest(_base.MappedTest):
 
         g2 = sess.query(Graph).get(Version(1, 1))
         eq_(g.version, g2.version)
+        
+        # TODO: can't seem to get NULL in for a PK value
+        # in either mysql or postgres, autoincrement=False etc.
+        # notwithstanding
+        @testing.fails_on_everything_except("sqlite")
+        def go():
+            g = Graph(Version(2, None))
+            sess.save(g)
+            sess.flush()
+            sess.clear()
+            g2 = sess.query(Graph).filter_by(version=Version(2, None)).one()
+            eq_(g.version, g2.version)
+        go()
+        
+    @testing.resolve_artifact_names
+    def test_attributes_with_defaults(self):
+        class Foobar(object):
+            pass
+
+        class FBComposite(object):
+            def __init__(self, x1, x2, x3, x4):
+                self.x1 = x1
+                self.x2 = x2
+                self.x3 = x3
+                self.x4 = x4
+            def __composite_values__(self):
+                return self.x1, self.x2, self.x3, self.x4
+            def __eq__(self, other):
+                return other.x1 == self.x1 and other.x2 == self.x2 and other.x3 == self.x3 and other.x4 == self.x4
+            def __ne__(self, other):
+                return not self.__eq__(other)
+
+        mapper(Foobar, foobars, properties=dict(
+            foob=sa.orm.composite(FBComposite, foobars.c.x1, foobars.c.x2, foobars.c.x3, foobars.c.x4)
+        ))
+
+        sess = create_session()
+        f1 = Foobar()
+        f1.foob = FBComposite(None, 5, None, None)
+        sess.save(f1)
+        sess.flush()
+
+        assert f1.foob == FBComposite(2, 5, 15, None)
+    
+    @testing.resolve_artifact_names
+    def test_set_composite_values(self):
+        class Foobar(object):
+            pass
+        
+        class FBComposite(object):
+            def __init__(self, x1, x2, x3, x4):
+                self.x1val = x1
+                self.x2val = x2
+                self.x3 = x3
+                self.x4 = x4
+            def __composite_values__(self):
+                return self.x1val, self.x2val, self.x3, self.x4
+            def __set_composite_values__(self, x1, x2, x3, x4):
+                self.x1val = x1
+                self.x2val = x2
+                self.x3 = x3
+                self.x4 = x4
+            def __eq__(self, other):
+                return other.x1val == self.x1val and other.x2val == self.x2val and other.x3 == self.x3 and other.x4 == self.x4
+            def __ne__(self, other):
+                return not self.__eq__(other)
+        
+        mapper(Foobar, foobars, properties=dict(
+            foob=sa.orm.composite(FBComposite, foobars.c.x1, foobars.c.x2, foobars.c.x3, foobars.c.x4)
+        ))
+        
+        sess = create_session()
+        f1 = Foobar()
+        f1.foob = FBComposite(None, 5, None, None)
+        sess.save(f1)
+        sess.flush()
+        
+        assert f1.foob == FBComposite(2, 5, 15, None)
+    
+    @testing.resolve_artifact_names
+    def test_save_null(self):
+        """test saving a null composite value
+        
+        See google groups thread for more context:
+        http://groups.google.com/group/sqlalchemy/browse_thread/thread/0c6580a1761b2c29
+        
+        """
+        class Point(object):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+            def __composite_values__(self):
+                return [self.x, self.y]
+            def __eq__(self, other):
+                return other.x == self.x and other.y == self.y
+            def __ne__(self, other):
+                return not self.__eq__(other)
+
+        class Graph(object):
+            pass
+        class Edge(object):
+            def __init__(self, start, end):
+                self.start = start
+                self.end = end
+
+        mapper(Graph, graphs, properties={
+            'edges':relation(Edge)
+        })
+        mapper(Edge, edges, properties={
+            'start':sa.orm.composite(Point, edges.c.x1, edges.c.y1),
+            'end':sa.orm.composite(Point, edges.c.x2, edges.c.y2)
+        })
+
+        sess = create_session()
+        g = Graph()
+        g.id = 1
+        g.version_id=1
+        e = Edge(None, None)
+        g.edges.append(e)
+        
+        sess.save(g)
+        sess.flush()
+        
+        sess.clear()
+        
+        g2 = sess.query(Graph).get([1, 1])
+        assert g2.edges[-1].start.x is None
+        assert g2.edges[-1].start.y is None
 
 
 class NoLoadTest(_fixtures.FixtureTest):