]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed up vertical.py
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jun 2008 15:23:08 +0000 (15:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jun 2008 15:23:08 +0000 (15:23 +0000)
- Fixed query.join() when used in conjunction with a
columns-only clause and an SQL-expression
ON clause in the join.

CHANGES
examples/vertical/vertical.py
lib/sqlalchemy/orm/query.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 50b9fb418a210f3f84833e88a3a97d7e3a0b9d3c..5f74037738c63266326c93bb18caba7f0196b71e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -24,6 +24,10 @@ CHANGES
       joined-table inheritance mappers when using 
       query.join(cls, aliased=True).  [ticket:1082]
 
+    - Fixed query.join() when used in conjunction with a
+      columns-only clause and an SQL-expression 
+      ON clause in the join.
+      
     - Repaired `__str__()` method on Query. [ticket:1066]
 
 - sqlite
index 225beeffe92b11f657f752eaf856bb47612d03c8..6c3a61919ef9140a40670dd1e3f6d93e88cc9696 100644 (file)
@@ -7,31 +7,25 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm.collections import mapped_collection
 import datetime
 
-e = MetaData('sqlite://')
-e.bind.echo = True
+engine = create_engine('sqlite://', echo=False)
+meta = MetaData(engine)
 
-Session = scoped_session(sessionmaker(transactional=True))
+Session = scoped_session(sessionmaker())
 
-# this table represents Entity objects.  each Entity gets a row in this table,
-# with a primary key and a title.
-entities = Table('entities', e, 
+# represent Entity objects
+entities = Table('entities', meta, 
     Column('entity_id', Integer, primary_key=True),
     Column('title', String(100), nullable=False),
     )
 
-# this table represents dynamic fields that can be associated
-# with values attached to an Entity.
-# a field has an ID, a name, and a datatype.  
-entity_fields = Table('entity_fields', e,
+# represent named, typed fields
+entity_fields = Table('entity_fields', meta,
     Column('field_id', Integer, primary_key=True),
     Column('name', String(40), nullable=False),
     Column('datatype', String(30), nullable=False))
     
-# this table represents attributes that are attached to an 
-# Entity object.  It combines a row from entity_fields with an actual value.
-# the value is stored in one of four columns, corresponding to the datatype
-# of the field value.
-entity_values = Table('entity_values', e, 
+# associate a field row with an entity row, including a typed value
+entity_values = Table('entity_values', meta, 
     Column('value_id', Integer, primary_key=True),
     Column('field_id', Integer, ForeignKey('entity_fields.field_id'), nullable=False),
     Column('entity_id', Integer, ForeignKey('entities.entity_id'), nullable=False),
@@ -40,135 +34,167 @@ entity_values = Table('entity_values', e,
     Column('binary_value', PickleType),
     Column('datetime_value', DateTime))
 
-e.create_all()
+meta.create_all()
 
 class Entity(object):
-    """represents an Entity.  The __getattr__ method is overridden to search the
-    object's _entities dictionary for the appropriate value, and the __setattribute__
-    method is overridden to set all non "_" attributes as EntityValues within the 
-    _entities dictionary. """
-
+    """a persistable dynamic object.  
+    
+    Marshalls attributes into a dictionary which is 
+    mapped to the database.
+    
+    """
+    def __init__(self, **kwargs):
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+            
     def __getattr__(self, key):
-        """getattr proxies requests for attributes which dont 'exist' on the object
-        to the underying _entities dictionary."""
+        """Proxy requests for attributes to the underlying _entities dictionary."""
+
         if key[0] == '_':
             return super(Entity, self).__getattr__(key)
         try:
             return self._entities[key].value
         except KeyError:
             raise AttributeError(key)
+
     def __setattr__(self, key, value):
-        """setattr proxies certain requests to set attributes as EntityValues within
-        the _entities dictionary.  This one is tricky as it has to allow underscore attributes,
-        as well as attributes managed by the Mapper, to be set by default mechanisms.  Since the 
-        mapper uses property-like objects on the Entity class to manage attributes, we check
-        for the key as an attribute of the class and if present, default to normal __setattr__
-        mechanisms, else we use the special logic which creates EntityValue objects in the
-        _entities dictionary."""
+        """Proxy requests for attribute set operations to the underlying _entities dictionary."""
+
         if key[0] == "_" or hasattr(Entity, key):
             object.__setattr__(self, key, value)
             return
+            
         try:
             ev = self._entities[key]
             ev.value = value
         except KeyError:
-            ev = EntityValue(key, value)
+            ev = _EntityValue(key, value)
             self._entities[key] = ev
         
-class EntityField(object):
-    """this class represents a row in the entity_fields table."""
-    def __init__(self, name=None):
+class _EntityField(object):
+    """Represents a field of a particular name and datatype."""
+
+    def __init__(self, name, datatype):
         self.name = name
-        self.datatype = None
-
-class EntityValue(object):
-    """the main job of EntityValue is to hold onto a value, corresponding the type of 
-    the value to the underlying datatype of its EntityField."""
-    def __init__(self, key=None, value=None):
-        if key is not None:
-            self.field = Session.query(EntityField).filter(EntityField.name==key).first() or EntityField(key)
-            if self.field.datatype is None:
-                if isinstance(value, int):
-                    self.field.datatype = 'int'
-                elif isinstance(value, str):
-                    self.field.datatype = 'string'
-                elif isinstance(value, datetime.datetime):
-                    self.field.datatype = 'datetime'
-                else:
-                    self.field.datatype = 'binary'
-            setattr(self, self.field.datatype + "_value", value)
+        self.datatype = datatype
+
+class _EntityValue(object):
+    """Represents an individual value."""
+
+    def __init__(self, key, value):
+        datatype = self._figure_datatype(value)
+        field = \
+            Session.query(_EntityField).filter(
+                and_(_EntityField.name==key, _EntityField.datatype==datatype)
+            ).first()
+
+        if not field:
+            field = _EntityField(key, datatype)
+            Session.add(field)
+        
+        self.field = field
+        setattr(self, self.field.datatype + "_value", value)
+    
+    def _figure_datatype(self, value):
+        typemap = {
+            int:'int',
+            str:'string',
+            datetime.datetime:'datetime',
+        }
+        for k in typemap:
+            if isinstance(value, k):
+                return typemap[k]
+        else:
+            return 'binary'
+
     def _get_value(self):
         return getattr(self, self.field.datatype + "_value")
+
     def _set_value(self, value):
         setattr(self, self.field.datatype + "_value", value)
-    name = property(lambda s: s.field.name)
     value = property(_get_value, _set_value)
+    
+    def name(self):
+        return self.field.name
+    name = property(name)
+
 
 # the mappers are a straightforward eager chain of 
 # Entity--(1->many)->EntityValue-(many->1)->EntityField
 # notice that we are identifying each mapper to its connecting
 # relation by just the class itself.
-mapper(EntityField, entity_fields)
+mapper(_EntityField, entity_fields)
 mapper(
-    EntityValue, entity_values,
+    _EntityValue, entity_values,
     properties = {
-        'field' : relation(EntityField, lazy=False, cascade='all')
+        'field' : relation(_EntityField, lazy=False, cascade='all')
     }
 )
 
 mapper(Entity, entities, properties = {
-    '_entities' : relation(EntityValue, lazy=False, cascade='all', collection_class=mapped_collection(lambda entityvalue: entityvalue.field.name))
+    '_entities' : relation(
+                        _EntityValue, 
+                        lazy=False, 
+                        cascade='all', 
+                        collection_class=mapped_collection(lambda entityvalue: entityvalue.field.name)
+                    )
 })
 
-# create two entities.  the objects can be used about as regularly as
-# any object can.
 session = Session()
-entity = Entity()
-entity.title = 'this is the first entity'
-entity.name =  'this is the name'
-entity.price = 43
-entity.data = ('hello', 'there')
-
-entity2 = Entity()
-entity2.title = 'this is the second entity'
-entity2.name = 'this is another name'
-entity2.price = 50
-entity2.data = ('hoo', 'ha')
-
-# commit
-[session.save(x) for x in (entity, entity2)]
-session.flush()
-
-# we would like to illustate loading everything totally clean from 
-# the database, so we clear out the session
-session.clear()
-
-# select both objects and print
-entities = session.query(Entity).select()
-for entity in entities:
-    print entity.title, entity.name, entity.price, entity.data
-
-# now change some data
-entities[0].price=90
-entities[0].title = 'another new title'
-entities[1].data = {'oof':5,'lala':8}
-entity3 = Entity()
-entity3.title = 'third entity'
-entity3.name = 'new name'
-entity3.price = '$1.95'
-entity3.data = 'some data'
-session.save(entity3)
-
-# commit changes.  the correct rows are updated, nothing else.
-session.flush()
-
-# lets see if that one came through.  clear the session, re-select
-# and print
-session.clear()
-entities = session.query(Entity).select()
-for entity in entities:
-    print entity.title, entity.name, entity.price, entity.data
-
-for entity in entities:
+entity1 = Entity(
+    title = 'this is the first entity',
+    name = 'this is the name',
+    price = 43,
+    data = ('hello', 'there')
+)
+
+entity2 = Entity(
+    title = 'this is the second entity',
+    name = 'this is another name',
+    price = 50,
+    data = ('hoo', 'ha')
+)
+
+session.add_all([entity1, entity2])
+session.commit()
+
+for entity in session.query(Entity):
+    print "Entity id %d:" % entity.entity_id, entity.title, entity.name, entity.price, entity.data
+
+# perform some changes, add a new Entity
+
+entity1.price = 90
+entity1.title = 'another new title'
+entity2.data = {'oof':5,'lala':8}
+
+entity3 = Entity(
+    title = 'third entity',
+    name = 'new name',
+    price = '$1.95',    # note we change 'price' to be a string.
+                        # this creates a new _EntityField separate from the
+                        # one used by integer 'price'.
+    data = 'some data'
+)
+session.add(entity3)
+
+session.commit()
+
+print "----------------"
+for entity in session.query(Entity):
+    print "Entity id %d:" % entity.entity_id, entity.title, entity.name, entity.price, entity.data
+
+print "----------------"
+# illustrate each _EntityField that's been created and list each Entity which uses it
+for ent_id, name, datatype in session.query(_EntityField.field_id, _EntityField.name, _EntityField.datatype):
+    print name, datatype, "(Enitites:",  ",".join([
+        str(entid) for entid in session.query(Entity.entity_id).\
+            join(
+                (_EntityValue, _EntityValue.entity_id==Entity.entity_id), 
+                (_EntityField, _EntityField.field_id==_EntityValue.field_id)
+            ).filter(_EntityField.field_id==ent_id)
+    ]), ")"
+
+# delete all the Entity objects
+for entity in session.query(Entity):
     session.delete(entity)
-session.flush()
+session.commit()
index ed82a7ca56572293678073e2319eac35dfb9605f..07caae07af555c467ad835340e1d8573ce960fc9 100644 (file)
@@ -222,6 +222,9 @@ class Query(object):
         return getattr(ent, 'extension', ent.mapper.extension)
 
     def _mapper_entities(self):
+        # TODO: this is wrong, its hardcoded to "priamry entity" when
+        # for the case of __all_equivs() it should not be
+        # the name of this accessor is wrong too
         for ent in self._entities:
             if hasattr(ent, 'primary_entity'):
                 yield ent
@@ -820,13 +823,13 @@ class Query(object):
                 if isinstance(onclause, interfaces.PropComparator):
                     clause = onclause.__clause_element__()
 
-                for ent in self._mapper_entities:
+                for ent in self._entities:
                     if ent.corresponds_to(left_entity):
                         clause = ent.selectable
                         break
 
             if not clause:
-                raise exc.InvalidRequestError("Could not find a FROM clause to join from")
+                raise sa_exc.InvalidRequestError("Could not find a FROM clause to join from")
 
             bogus, right_selectable, is_aliased_class = _entity_info(right_entity)
 
@@ -1619,6 +1622,15 @@ class _ColumnEntity(_QueryEntity):
         self.selectable = from_obj
         self.froms.add(from_obj)
 
+    def corresponds_to(self, entity):
+        if _is_aliased_class(entity):
+            return entity is self.entity_zero
+        else:
+            # TODO: this will fail with inheritance, entity_zero
+            # is not a base mapper.  MapperEntity has path_entity
+            # which serves this purpose (when saying: query(FooBar.somecol).join(SomeClass, FooBar.id==SomeClass.foo_id))
+            return entity.base_mapper is self.entity_zero
+
     def _resolve_expr_against_query_aliases(self, query, expr, context):
         return query._adapt_clause(expr, False, True)
 
index d3d0643648418361f066516604acd601e0e639d7..0c274efefdb46659de5efcab28a5d6a00aadac94 100644 (file)
@@ -169,6 +169,7 @@ class GetTest(QueryTest):
         assert u.addresses[0].email_address == 'jack@bean.com'
         assert u.orders[1].items[2].description == 'item 5'
 
+    @testing.fails_on_everything_except('sqlite')
     def test_query_str(self):
         s = create_session()
         q = s.query(User).filter(User.id==1)
@@ -731,6 +732,11 @@ class JoinTest(QueryTest):
             [User(name='jack')]
         )
 
+        self.assertEquals(
+            sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
+            [('jack',)]
+        )
+
         self.assertEquals(
             sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(),
             [User(name='jack')]
@@ -748,8 +754,25 @@ class JoinTest(QueryTest):
             [User(name='jack')]
         )
 
-        # no arg error
-        result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
+        self.assertEquals(
+            sess.query(User.name).join(
+                (Order, User.id==Order.user_id), 
+                (order_items, Order.id==order_items.c.order_id), 
+                (Item, order_items.c.item_id==Item.id)
+            ).filter(Item.description == 'item 4').all(),
+            [('jack',)]
+        )
+
+        ualias = aliased(User)
+        self.assertEquals(
+            sess.query(ualias.name).join(
+                (Order, ualias.id==Order.user_id), 
+                (order_items, Order.id==order_items.c.order_id), 
+                (Item, order_items.c.item_id==Item.id)
+            ).filter(Item.description == 'item 4').all(),
+            [('jack',)]
+        )
+
         
     def test_aliased_classes(self):
         sess = create_session()