From: Mike Bayer Date: Wed, 8 Mar 2006 20:51:51 +0000 (+0000) Subject: serious overhaul to get eager loads to work inline with an inheriting mapper, when... X-Git-Tag: rel_0_1_4~25 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=06b8c73ad5136f615957bdf4e535330885ae1635;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git serious overhaul to get eager loads to work inline with an inheriting mapper, when the inheritance/eager loads share the same table. mapper inheritance will also favor the columns from the child table over those of the parent table when assigning column values to object attributes. "correlated subqueries" require a flag "correlated=True" if they are in the FROM clause of another SELECT statement, and they want to be correlated. this flag is set by default when using an "exists" clause. --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 518ea8a15c..d613205339 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -712,6 +712,7 @@ class ResultProxy: if self.props.setdefault(colname, rec) is not rec: self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0) self.keys.append(colname) + #print "COLNAME", colname self.props[i] = rec i+=1 @@ -719,6 +720,7 @@ class ResultProxy: if isinstance(key, schema.Column) or isinstance(key, sql.ColumnElement): try: rec = self.props[key._label.lower()] + #print "GOT IT FROM LABEL FOR ", key._label except KeyError: try: rec = self.props[key.key.lower()] diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 330c75c313..93af9d7f21 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -145,7 +145,12 @@ class Mapper(object): prop = ColumnProperty(column) self.props[column.key] = prop elif isinstance(prop, ColumnProperty): - prop.columns.append(column) + # the order which columns are appended to a ColumnProperty is significant, as the + # column at index 0 determines which result column is used to populate the object + # attribute, in the case of mapping against a join with column names repeated + # (and particularly in an inheritance relationship) + prop.columns.insert(0, column) + #prop.columns.append(column) else: if not allow_column_override: raise ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop))) @@ -179,6 +184,12 @@ class Mapper(object): if getattr(prop, 'key', None) is None: prop.init(key, self) + # this prints a summary of the object attributes and how they + # will be mapped to table columns + #print "mapper %s, columntoproperty:" % (self.class_.__name__) + #for key, value in self.columntoproperty.iteritems(): + # print key.table.name, key.key, [(v.key, v) for v in value] + engines = property(lambda s: [t.engine for t in s.tables]) def add_property(self, key, prop): @@ -638,9 +649,7 @@ class Mapper(object): def delete_obj(self, objects, uow): """called by a UnitOfWork object to delete objects, which involves a DELETE statement for each table used by this mapper, for each object in the list.""" - l = list(self.tables) - l.reverse() - for table in l: + for table in util.reversed(self.tables): if not self._has_pks(table): continue delete = [] @@ -703,7 +712,8 @@ class Mapper(object): order_by = self.table.default_order_by() if self._should_nest(**kwargs): - s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, **kwargs) + s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, from_obj=[self.table], **kwargs) +# raise "ok first thing", str(s2) if not kwargs.get('distinct', False) and order_by: s2.order_by(*util.to_list(order_by)) s3 = s2.alias('rowcount') @@ -711,6 +721,7 @@ class Mapper(object): for i in range(0, len(self.table.primary_key)): crit.append(s3.primary_key[i] == self.table.primary_key[i]) statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True) + # raise "OK statement", str(statement) if order_by: statement.order_by(*util.to_list(order_by)) else: @@ -930,6 +941,8 @@ class TableFinder(sql.ClauseVisitor): table.accept_visitor(self) def visit_table(self, table): self.tables.append(table) + def __len__(self): + return len(self.tables) def __getitem__(self, i): return self.tables[i] def __iter__(self): diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 183cf97f62..a8b1fde4ee 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -48,6 +48,7 @@ class ColumnProperty(MapperProperty): objectstore.uow().register_attribute(parent.class_, key, uselist = False) def execute(self, instance, row, identitykey, imap, isnew): if isnew: + #print "POPULATING OBJ", instance.__class__.__name__, "COL", self.columns[0]._label, "WITH DATA", row[self.columns[0]], "ROW IS A", row.__class__.__name__, "COL ID", id(self.columns[0]) instance.__dict__[self.key] = row[self.columns[0]] def __repr__(self): return "ColumnProperty(%s)" % repr([str(c) for c in self.columns]) @@ -648,16 +649,19 @@ class EagerLoader(PropertyLoader): parent._has_eager = True self.eagertarget = self.target.alias() +# print "ALIAS", str(self.eagertarget.select()) #selectable.__class__.__name__ if self.secondary: self.eagersecondary = self.secondary.alias() self.aliasizer = Aliasizer(self.target, self.secondary, aliases={ self.target:self.eagertarget, self.secondary:self.eagersecondary }) + #print "TARGET", self.target self.eagersecondaryjoin = self.secondaryjoin.copy_container() self.eagersecondaryjoin.accept_visitor(self.aliasizer) self.eagerprimary = self.primaryjoin.copy_container() self.eagerprimary.accept_visitor(self.aliasizer) + #print "JOINS:", str(self.eagerprimary), "|", str(self.eagersecondaryjoin) else: self.aliasizer = Aliasizer(self.target, aliases={self.target:self.eagertarget}) self.eagerprimary = self.primaryjoin.copy_container() @@ -778,7 +782,8 @@ class EagerLoader(PropertyLoader): """gets an instance from a row, via this EagerLoader's mapper.""" fakerow = util.DictDecorator(row) for c in self.eagertarget.c: - fakerow[c.parent] = row[c] + parent = self.target._get_col_by_original(c.original) + fakerow[parent] = row[c] row = fakerow return self.mapper._instance(row, imap, result_list) @@ -882,15 +887,18 @@ class Aliasizer(sql.ClauseVisitor): """converts a table instance within an expression to be an alias of that table.""" def __init__(self, *tables, **kwargs): self.tables = {} + self.aliases = kwargs.get('aliases', {}) for t in tables: self.tables[t] = t + if not self.aliases.has_key(t): + self.aliases[t] = sql.alias(t) + if isinstance(t, sql.Join): + for t2 in t.columns: + self.tables[t2.table] = t2 + self.aliases[t2.table] = self.aliases[t] self.binary = None - self.aliases = kwargs.get('aliases', {}) def get_alias(self, table): - try: - return self.aliases[table] - except: - return self.aliases.setdefault(table, sql.alias(table)) + return self.aliases[table] def visit_compound(self, compound): self.visit_clauselist(compound) def visit_clauselist(self, clist): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 73481009d7..04949c9354 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -129,6 +129,7 @@ def between_(ctest, cleft, cright): return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN') def exists(*args, **params): + params['correlate'] = True s = select(*args, **params) return BooleanExpression(TextClause("EXISTS"), s, None) @@ -839,7 +840,7 @@ class Join(FromClause): self.left = left self.right = right self.id = self.left.id + "_" + self.right.id - + # TODO: if no onclause, do NATURAL JOIN if onclause is None: self.onclause = self._match_primaries(left, right) @@ -852,7 +853,7 @@ class Join(FromClause): def _exportable_columns(self): return [c for c in self.left.columns] + [c for c in self.right.columns] def _proxy_column(self, column): - self._columns[column.table.name + "_" + column.key] = column + self._columns[column._label] = column if column.primary_key: self._primary_key.append(column) if column.foreign_key: @@ -894,7 +895,9 @@ class Join(FromClause): self.join = join def _exportable_columns(self): return [] - + + def alias(self, name=None): + return self.select(use_labels=True).alias(name) def _process_from_dict(self, data, asfrom): for f in self.onclause._get_from_objects(): data[f.id] = f @@ -915,7 +918,7 @@ class Alias(FromClause): self.original = baseselectable self.selectable = selectable if alias is None: - n = getattr(self.original, 'name') + n = getattr(self.original, 'name', None) if n is None: n = 'anon' elif len(n) > 15: @@ -974,7 +977,7 @@ class ColumnClause(ColumnElement): self.__label = None def _get_label(self): if self.__label is None: - if self.table is not None: + if self.table is not None and self.table.name is not None: self.__label = self.table.name + "_" + self.text else: self.__label = self.text @@ -1164,7 +1167,7 @@ class CompoundSelect(SelectBaseMixin, FromClause): class Select(SelectBaseMixin, FromClause): """represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" - def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None): + def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None, correlate=False): self._froms = util.OrderedDict() self.use_labels = use_labels self.id = "Select(%d)" % id(self) @@ -1175,6 +1178,7 @@ class Select(SelectBaseMixin, FromClause): self.oid_column = None self.limit = limit self.offset = offset + self.correlate = correlate # indicates if this select statement is a subquery inside another query self.issubquery = False @@ -1224,9 +1228,11 @@ class Select(SelectBaseMixin, FromClause): select.is_where = self.is_where select.issubquery = True select.parens = True + if not self.is_where and not select.correlate: + return if getattr(select, '_correlated', None) is None: select._correlated = self.select._froms - + def append_column(self, column): if _is_literal(column): column = ColumnClause(str(column), self) @@ -1266,7 +1272,8 @@ class Select(SelectBaseMixin, FromClause): def append_from(self, fromclause): if type(fromclause) == str: fromclause = FromClause(from_name = fromclause) - + if self.oid_column is None and hasattr(fromclause, 'oid_column'): + self.oid_column = fromclause.oid_column fromclause.accept_visitor(self._correlator) fromclause._process_from_dict(self._froms, True) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 7115dbcec4..d60166a0c1 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -7,6 +7,7 @@ __all__ = ['OrderedProperties', 'OrderedDict', 'generic_repr', 'HashSet', 'AttrProp'] import thread, threading, weakref, UserList, time, string, inspect, sys from exceptions import * +import __builtin__ def to_list(x): if x is None: @@ -24,6 +25,18 @@ def to_set(x): else: return x +def reversed(seq): + try: + return __builtin__.reversed(seq) + except: + def rev(): + i = len(seq) -1 + while i >= 0: + yield seq[i] + i -= 1 + raise StopIteration() + return rev() + class AttrProp(object): """a quick way to stick a property accessor on an object""" def __init__(self, key): diff --git a/test/inheritance.py b/test/inheritance.py index 4400cab891..d679cf8c17 100644 --- a/test/inheritance.py +++ b/test/inheritance.py @@ -2,8 +2,7 @@ from sqlalchemy import * import testbase import string import sqlalchemy.attributes as attr - - +import sys class Principal( object ): pass @@ -54,8 +53,6 @@ class InheritTest(testbase.AssertMixin): ) - - principals.create() users.create() groups.create() @@ -65,6 +62,7 @@ class InheritTest(testbase.AssertMixin): groups.drop() users.drop() principals.drop() + testbase.db.tables.clear() def setUp(self): objectstore.clear() clear_mappers() @@ -111,6 +109,7 @@ class InheritTest2(testbase.AssertMixin): foo_bar.drop() bar.drop() foo.drop() + testbase.db.tables.clear() def testbasic(self): class Foo(object): @@ -155,6 +154,112 @@ class InheritTest2(testbase.AssertMixin): {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])}, ) +class InheritTest3(testbase.AssertMixin): + def setUpAll(self): + engine = testbase.db + global foo, bar, blub, bar_foo, blub_bar, blub_foo,tables + engine.engine.echo = 'debug' + # the 'data' columns are to appease SQLite which cant handle a blank INSERT + foo = Table('foo', engine, + Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('data', String(20))) + + bar = Table('bar', engine, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(20))) + + blub = Table('blub', engine, + Column('id', Integer, ForeignKey('bar.id'), primary_key=True), + Column('data', String(20))) + + bar_foo = Table('bar_foo', engine, + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('foo_id', Integer, ForeignKey('foo.id'))) + + blub_bar = Table('bar_blub', engine, + Column('blub_id', Integer, ForeignKey('blub.id')), + Column('bar_id', Integer, ForeignKey('bar.id'))) + + blub_foo = Table('blub_foo', engine, + Column('blub_id', Integer, ForeignKey('blub.id')), + Column('foo_id', Integer, ForeignKey('foo.id'))) + + tables = [foo, bar, blub, bar_foo, blub_bar, blub_foo] + for table in tables: + table.create() + def tearDownAll(self): + for table in reversed(tables): + table.drop() + testbase.db.tables.clear() + + def testbasic(self): + class Foo(object): + def __repr__(self): + return "Foo id %d, data %s" % (self.id, self.data) + Foo.mapper = mapper(Foo, foo) + + class Bar(object): + def __repr__(self): + return "Bar id %d, data %s" % (self.id, self.data) + + Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties={ + 'foos' :relation(Foo.mapper, bar_foo, primaryjoin=bar.c.id==bar_foo.c.bar_id, secondaryjoin=bar_foo.c.foo_id==foo.c.id, lazy=False) + }) + + Bar.mapper.select() + + def testadvanced(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + def __repr__(self): + return "Foo id %d, data %s" % (self.id, self.data) + Foo.mapper = mapper(Foo, foo) + + class Bar(Foo): + def __repr__(self): + return "Bar id %d, data %s" % (self.id, self.data) + Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper) + + class Blub(Bar): + def __repr__(self): + return "Blub id %d, data %s, bars %s, foos %s" % (self.id, self.data, repr([b for b in self.bars]), repr([f for f in self.foos])) + + Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={ + 'bars':relation(Bar.mapper, blub_bar, primaryjoin=blub.c.id==blub_bar.c.blub_id, secondaryjoin=blub_bar.c.bar_id==bar.c.id, lazy=False), + 'foos':relation(Foo.mapper, blub_foo, primaryjoin=blub.c.id==blub_foo.c.blub_id, secondaryjoin=blub_foo.c.foo_id==foo.c.id, lazy=False), + }) + + useobjects = True + if (useobjects): + f1 = Foo("foo #1") + b1 = Bar("bar #1") + b2 = Bar("bar #2") + bl1 = Blub("blub #1") + bl1.foos.append(f1) + bl1.bars.append(b2) + objectstore.commit() + compare = repr(bl1) + blubid = bl1.id + objectstore.clear() + else: + foo.insert().execute(data='foo #1') + foo.insert().execute(data='foo #2') + bar.insert().execute(id=1, data="bar #1") + bar.insert().execute(id=2, data="bar #2") + blub.insert().execute(id=1, data="blub #1") + blub_bar.insert().execute(blub_id=1, bar_id=2) + blub_foo.insert().execute(blub_id=1, foo_id=2) + + l = Blub.mapper.select() + for x in l: + print x + + self.assert_(repr(l[0]) == compare) + objectstore.clear() + x = Blub.mapper.get_by(id=blubid) #traceback 2 + self.assert_(repr(x) == compare) + if __name__ == "__main__": testbase.main() diff --git a/test/select.py b/test/select.py index 20454a9fc4..47bc19515c 100644 --- a/test/select.py +++ b/test/select.py @@ -418,7 +418,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable select([s, table1]) ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS myid, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable") - s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s') + s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s') self.runtest( select([users, s.c.street], from_obj=[s]), """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")