]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
serious overhaul to get eager loads to work inline with an inheriting mapper, when...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 8 Mar 2006 20:51:51 +0000 (20:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 8 Mar 2006 20:51:51 +0000 (20:51 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/inheritance.py
test/select.py

index 518ea8a15c5ba2440ac132970c989802a6acd654..d61320533963ec8d793670b52b6885eadaf06da8 100644 (file)
@@ -712,6 +712,7 @@ class ResultProxy:
                 if self.props.setdefault(colname, rec) is not rec:
                     self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0)
                 self.keys.append(colname)
+                #print "COLNAME", colname
                 self.props[i] = rec
                 i+=1
 
@@ -719,6 +720,7 @@ class ResultProxy:
         if isinstance(key, schema.Column) or isinstance(key, sql.ColumnElement):
             try:
                 rec = self.props[key._label.lower()]
+                #print "GOT IT FROM LABEL FOR ", key._label
             except KeyError:
                 try:
                     rec = self.props[key.key.lower()]
index 330c75c313fcbd458438585df06608d0e5a9d0bd..93af9d7f21f1886f3988913ed67c5eef8cae9c30 100644 (file)
@@ -145,7 +145,12 @@ class Mapper(object):
                 prop = ColumnProperty(column)
                 self.props[column.key] = prop
             elif isinstance(prop, ColumnProperty):
-                prop.columns.append(column)
+                # the order which columns are appended to a ColumnProperty is significant, as the 
+                # column at index 0 determines which result column is used to populate the object
+                # attribute, in the case of mapping against a join with column names repeated
+                # (and particularly in an inheritance relationship)
+                prop.columns.insert(0, column)
+                #prop.columns.append(column)
             else:
                 if not allow_column_override:
                     raise ArgumentError("WARNING: column '%s' not being added due to property '%s'.  Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop)))
@@ -179,6 +184,12 @@ class Mapper(object):
             if getattr(prop, 'key', None) is None:
                 prop.init(key, self)
 
+        # this prints a summary of the object attributes and how they
+        # will be mapped to table columns
+        #print "mapper %s, columntoproperty:" % (self.class_.__name__)
+        #for key, value in self.columntoproperty.iteritems():
+        #    print key.table.name, key.key, [(v.key, v) for v in value]
+            
     engines = property(lambda s: [t.engine for t in s.tables])
 
     def add_property(self, key, prop):
@@ -638,9 +649,7 @@ class Mapper(object):
     def delete_obj(self, objects, uow):
         """called by a UnitOfWork object to delete objects, which involves a
         DELETE statement for each table used by this mapper, for each object in the list."""
-        l = list(self.tables)
-        l.reverse()
-        for table in l:
+        for table in util.reversed(self.tables):
             if not self._has_pks(table):
                 continue
             delete = []
@@ -703,7 +712,8 @@ class Mapper(object):
                 order_by = self.table.default_order_by()
 
         if self._should_nest(**kwargs):
-            s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, **kwargs)
+            s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, from_obj=[self.table], **kwargs)
+#            raise "ok first thing", str(s2)
             if not kwargs.get('distinct', False) and order_by:
                 s2.order_by(*util.to_list(order_by))
             s3 = s2.alias('rowcount')
@@ -711,6 +721,7 @@ class Mapper(object):
             for i in range(0, len(self.table.primary_key)):
                 crit.append(s3.primary_key[i] == self.table.primary_key[i])
             statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True)
+ #           raise "OK statement", str(statement)
             if order_by:
                 statement.order_by(*util.to_list(order_by))
         else:
@@ -930,6 +941,8 @@ class TableFinder(sql.ClauseVisitor):
         table.accept_visitor(self)
     def visit_table(self, table):
         self.tables.append(table)
+    def __len__(self):
+        return len(self.tables)
     def __getitem__(self, i):
         return self.tables[i]
     def __iter__(self):
index 183cf97f628269b235e723f437bea5b9bde28cae..a8b1fde4ee27948fdfbc60e3fc0a52fb2b29a191 100644 (file)
@@ -48,6 +48,7 @@ class ColumnProperty(MapperProperty):
             objectstore.uow().register_attribute(parent.class_, key, uselist = False)
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
+            #print "POPULATING OBJ", instance.__class__.__name__, "COL", self.columns[0]._label, "WITH DATA", row[self.columns[0]], "ROW IS A", row.__class__.__name__, "COL ID", id(self.columns[0])
             instance.__dict__[self.key] = row[self.columns[0]]
     def __repr__(self):
         return "ColumnProperty(%s)" % repr([str(c) for c in self.columns])
@@ -648,16 +649,19 @@ class EagerLoader(PropertyLoader):
         parent._has_eager = True
 
         self.eagertarget = self.target.alias()
+#        print "ALIAS", str(self.eagertarget.select()) #selectable.__class__.__name__
         if self.secondary:
             self.eagersecondary = self.secondary.alias()
             self.aliasizer = Aliasizer(self.target, self.secondary, aliases={
                     self.target:self.eagertarget,
                     self.secondary:self.eagersecondary
                     })
+            #print "TARGET", self.target
             self.eagersecondaryjoin = self.secondaryjoin.copy_container()
             self.eagersecondaryjoin.accept_visitor(self.aliasizer)
             self.eagerprimary = self.primaryjoin.copy_container()
             self.eagerprimary.accept_visitor(self.aliasizer)
+            #print "JOINS:", str(self.eagerprimary), "|", str(self.eagersecondaryjoin)
         else:
             self.aliasizer = Aliasizer(self.target, aliases={self.target:self.eagertarget})
             self.eagerprimary = self.primaryjoin.copy_container()
@@ -778,7 +782,8 @@ class EagerLoader(PropertyLoader):
         """gets an instance from a row, via this EagerLoader's mapper."""
         fakerow = util.DictDecorator(row)
         for c in self.eagertarget.c:
-            fakerow[c.parent] = row[c]
+            parent = self.target._get_col_by_original(c.original)
+            fakerow[parent] = row[c]
         row = fakerow
         return self.mapper._instance(row, imap, result_list)
 
@@ -882,15 +887,18 @@ class Aliasizer(sql.ClauseVisitor):
     """converts a table instance within an expression to be an alias of that table."""
     def __init__(self, *tables, **kwargs):
         self.tables = {}
+        self.aliases = kwargs.get('aliases', {})
         for t in tables:
             self.tables[t] = t
+            if not self.aliases.has_key(t):
+                self.aliases[t] = sql.alias(t)
+            if isinstance(t, sql.Join):
+                for t2 in t.columns:
+                    self.tables[t2.table] = t2
+                    self.aliases[t2.table] = self.aliases[t]
         self.binary = None
-        self.aliases = kwargs.get('aliases', {})
     def get_alias(self, table):
-        try:
-            return self.aliases[table]
-        except:
-            return self.aliases.setdefault(table, sql.alias(table))
+        return self.aliases[table]
     def visit_compound(self, compound):
         self.visit_clauselist(compound)
     def visit_clauselist(self, clist):
index 73481009d79ac8d3e580863fbc178934eb199426..04949c93546e2c24eabe660ab3e8767ccb4a3e39 100644 (file)
@@ -129,6 +129,7 @@ def between_(ctest, cleft, cright):
     return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN')
         
 def exists(*args, **params):
+    params['correlate'] = True
     s = select(*args, **params)
     return BooleanExpression(TextClause("EXISTS"), s, None)
 
@@ -839,7 +840,7 @@ class Join(FromClause):
         self.left = left
         self.right = right
         self.id = self.left.id + "_" + self.right.id
-
+        
         # TODO: if no onclause, do NATURAL JOIN
         if onclause is None:
             self.onclause = self._match_primaries(left, right)
@@ -852,7 +853,7 @@ class Join(FromClause):
     def _exportable_columns(self):
         return [c for c in self.left.columns] + [c for c in self.right.columns]
     def _proxy_column(self, column):
-        self._columns[column.table.name + "_" + column.key] = column
+        self._columns[column._label] = column
         if column.primary_key:
             self._primary_key.append(column)
         if column.foreign_key:
@@ -894,7 +895,9 @@ class Join(FromClause):
             self.join = join
         def _exportable_columns(self):
             return []
-                
+    
+    def alias(self, name=None):
+        return self.select(use_labels=True).alias(name)            
     def _process_from_dict(self, data, asfrom):
         for f in self.onclause._get_from_objects():
             data[f.id] = f
@@ -915,7 +918,7 @@ class Alias(FromClause):
         self.original = baseselectable
         self.selectable = selectable
         if alias is None:
-            n = getattr(self.original, 'name')
+            n = getattr(self.original, 'name', None)
             if n is None:
                 n = 'anon'
             elif len(n) > 15:
@@ -974,7 +977,7 @@ class ColumnClause(ColumnElement):
         self.__label = None
     def _get_label(self):
         if self.__label is None:
-            if self.table is not None:
+            if self.table is not None and self.table.name is not None:
                 self.__label =  self.table.name + "_" + self.text
             else:
                 self.__label = self.text
@@ -1164,7 +1167,7 @@ class CompoundSelect(SelectBaseMixin, FromClause):
 class Select(SelectBaseMixin, FromClause):
     """represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
-    def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None):
+    def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None, correlate=False):
         self._froms = util.OrderedDict()
         self.use_labels = use_labels
         self.id = "Select(%d)" % id(self)
@@ -1175,6 +1178,7 @@ class Select(SelectBaseMixin, FromClause):
         self.oid_column = None
         self.limit = limit
         self.offset = offset
+        self.correlate = correlate
         
         # indicates if this select statement is a subquery inside another query
         self.issubquery = False
@@ -1224,9 +1228,11 @@ class Select(SelectBaseMixin, FromClause):
             select.is_where = self.is_where
             select.issubquery = True
             select.parens = True
+            if not self.is_where and not select.correlate:
+                return
             if getattr(select, '_correlated', None) is None:
                 select._correlated = self.select._froms
-
+                
     def append_column(self, column):
         if _is_literal(column):
             column = ColumnClause(str(column), self)
@@ -1266,7 +1272,8 @@ class Select(SelectBaseMixin, FromClause):
     def append_from(self, fromclause):
         if type(fromclause) == str:
             fromclause = FromClause(from_name = fromclause)
-
+        if self.oid_column is None and hasattr(fromclause, 'oid_column'):
+            self.oid_column = fromclause.oid_column
         fromclause.accept_visitor(self._correlator)
         fromclause._process_from_dict(self._froms, True)
 
index 7115dbcec47d2a0960e26f0c08dcaab990224705..d60166a0c1eba1975bacaa8c603c9509dbe0dcca 100644 (file)
@@ -7,6 +7,7 @@
 __all__ = ['OrderedProperties', 'OrderedDict', 'generic_repr', 'HashSet', 'AttrProp']
 import thread, threading, weakref, UserList, time, string, inspect, sys
 from exceptions import *
+import __builtin__
 
 def to_list(x):
     if x is None:
@@ -24,6 +25,18 @@ def to_set(x):
     else:
         return x
 
+def reversed(seq):
+    try:
+        return __builtin__.reversed(seq)
+    except:
+        def rev():
+            i = len(seq) -1
+            while  i >= 0:
+                yield seq[i]
+                i -= 1
+            raise StopIteration()
+        return rev()
+        
 class AttrProp(object):
     """a quick way to stick a property accessor on an object"""
     def __init__(self, key):
index 4400cab891ef10dde578548cd28db108ae805a42..d679cf8c1729691cbae77c96b20c78ad58c1f483 100644 (file)
@@ -2,8 +2,7 @@ from sqlalchemy import *
 import testbase
 import string
 import sqlalchemy.attributes as attr
-
-
+import sys
 
 class Principal( object ):
     pass
@@ -54,8 +53,6 @@ class InheritTest(testbase.AssertMixin):
 
                 )
 
-
-
             principals.create()
             users.create()
             groups.create()
@@ -65,6 +62,7 @@ class InheritTest(testbase.AssertMixin):
             groups.drop()
             users.drop()
             principals.drop()
+            testbase.db.tables.clear()
         def setUp(self):
             objectstore.clear()
             clear_mappers()
@@ -111,6 +109,7 @@ class InheritTest2(testbase.AssertMixin):
         foo_bar.drop()
         bar.drop()
         foo.drop()
+        testbase.db.tables.clear()
 
     def testbasic(self):
         class Foo(object): 
@@ -155,6 +154,112 @@ class InheritTest2(testbase.AssertMixin):
             {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])},
             )
 
+class InheritTest3(testbase.AssertMixin):
+    def setUpAll(self):
+        engine = testbase.db
+        global foo, bar, blub, bar_foo, blub_bar, blub_foo,tables
+        engine.engine.echo = 'debug'
+        # the 'data' columns are to appease SQLite which cant handle a blank INSERT
+        foo = Table('foo', engine,
+            Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+            Column('data', String(20)))
+
+        bar = Table('bar', engine,
+            Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+            Column('data', String(20)))
+
+        blub = Table('blub', engine,
+            Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
+            Column('data', String(20)))
+
+        bar_foo = Table('bar_foo', engine, 
+            Column('bar_id', Integer, ForeignKey('bar.id')),
+            Column('foo_id', Integer, ForeignKey('foo.id')))
+            
+        blub_bar = Table('bar_blub', engine,
+            Column('blub_id', Integer, ForeignKey('blub.id')),
+            Column('bar_id', Integer, ForeignKey('bar.id')))
+
+        blub_foo = Table('blub_foo', engine,
+            Column('blub_id', Integer, ForeignKey('blub.id')),
+            Column('foo_id', Integer, ForeignKey('foo.id')))
+
+        tables = [foo, bar, blub, bar_foo, blub_bar, blub_foo]
+        for table in tables:
+            table.create()
+    def tearDownAll(self):
+        for table in reversed(tables):
+            table.drop()
+        testbase.db.tables.clear()
+
+    def testbasic(self):
+        class Foo(object):
+            def __repr__(self):
+                return "Foo id %d, data %s" % (self.id, self.data)
+        Foo.mapper = mapper(Foo, foo)
+
+        class Bar(object):
+            def __repr__(self):
+                return "Bar id %d, data %s" % (self.id, self.data)
+                
+        Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties={
+            'foos' :relation(Foo.mapper, bar_foo, primaryjoin=bar.c.id==bar_foo.c.bar_id, secondaryjoin=bar_foo.c.foo_id==foo.c.id, lazy=False)
+        })
+
+        Bar.mapper.select()
+    
+    def testadvanced(self):    
+        class Foo(object):
+            def __init__(self, data=None):
+                self.data = data
+            def __repr__(self):
+                return "Foo id %d, data %s" % (self.id, self.data)
+        Foo.mapper = mapper(Foo, foo)
+
+        class Bar(Foo):
+            def __repr__(self):
+                return "Bar id %d, data %s" % (self.id, self.data)
+        Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper)
+
+        class Blub(Bar):
+            def __repr__(self):
+                return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos]))
+            
+        Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={
+            'bars':relation(Bar.mapper, blub_bar, primaryjoin=blub.c.id==blub_bar.c.blub_id, secondaryjoin=blub_bar.c.bar_id==bar.c.id, lazy=False),
+            'foos':relation(Foo.mapper, blub_foo, primaryjoin=blub.c.id==blub_foo.c.blub_id, secondaryjoin=blub_foo.c.foo_id==foo.c.id, lazy=False),
+        })
+
+        useobjects = True
+        if (useobjects):
+            f1 = Foo("foo #1")
+            b1 = Bar("bar #1")
+            b2 = Bar("bar #2")
+            bl1 = Blub("blub #1")
+            bl1.foos.append(f1)
+            bl1.bars.append(b2)
+            objectstore.commit()
+            compare = repr(bl1)
+            blubid = bl1.id
+            objectstore.clear()
+        else:
+            foo.insert().execute(data='foo #1')
+            foo.insert().execute(data='foo #2')
+            bar.insert().execute(id=1, data="bar #1")
+            bar.insert().execute(id=2, data="bar #2")
+            blub.insert().execute(id=1, data="blub #1")
+            blub_bar.insert().execute(blub_id=1, bar_id=2)
+            blub_foo.insert().execute(blub_id=1, foo_id=2)
+
+        l = Blub.mapper.select()
+        for x in l:
+            print x
+        
+        self.assert_(repr(l[0]) == compare)
+        objectstore.clear()
+        x = Blub.mapper.get_by(id=blubid) #traceback 2
+        self.assert_(repr(x) == compare)
+
 
 if __name__ == "__main__":    
     testbase.main()
index 20454a9fc4502a85bfd5242a52e403f977da1e30..47bc19515cdc0dd93784d937dcc554d162edd984 100644 (file)
@@ -418,7 +418,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
             select([s, table1])
             ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS myid, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable")
 
-        s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s')
+        s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s')
         self.runtest(
             select([users, s.c.street], from_obj=[s]),
             """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")