From: Mike Bayer Date: Tue, 27 Mar 2007 22:06:36 +0000 (+0000) Subject: - fixes [ticket:185], join object determines primary key and removes X-Git-Tag: rel_0_3_7~104 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=de4c25cd028d242eaf0adbba89731f1e791e1dfe;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fixes [ticket:185], join object determines primary key and removes columns that are FK's to other columns in the primary key collection. - removed workaround code from query.py get() - removed obsolete inheritance test from mapper - added new get() test to inheritance.py for this particular issue - ColumnCollection has nicer string method --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 070fb4cb5e..77499c2714 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -774,17 +774,9 @@ class Query(object): ident = key[1] else: ident = util.to_list(ident) - i = 0 params = {} - for primary_key in self.primary_key_columns: + for i, primary_key in enumerate(self.primary_key_columns): params[primary_key._label] = ident[i] - # if there are not enough elements in the given identifier, then - # use the previous identifier repeatedly. this is a workaround for the issue - # in [ticket:185], where a mapper that uses joined table inheritance needs to specify - # all primary keys of the joined relationship, which includes even if the join is joining - # two primary key (and therefore synonymous) columns together, the usual case for joined table inheritance. - if len(ident) > i + 1: - i += 1 try: statement = self.compile(self._get_clause, lockmode=lockmode) return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0] diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index bd601ed800..5ed95fabb5 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -680,7 +680,7 @@ class ForeignKey(SchemaItem): """Return True if the given table is referenced by this ``ForeignKey``.""" return table.corresponding_column(self.column, False) is not None - + def _init_column(self): # ForeignKey inits its remote column as late as possible, so tables can # be defined without dependencies diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 6924a60ceb..74f085cb1a 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1075,15 +1075,24 @@ class ColumnCollection(util.OrderedProperties): super(ColumnCollection, self).__init__() [self.add(c) for c in cols] + def __str__(self): + return repr([str(c) for c in self]) + def add(self, column): """Add a column to this collection. The key attribute of the column will be used as the hash key for this dictionary. """ - self[column.key] = column - + + def remove(self, column): + del self[column.key] + + def extend(self, iter): + for c in iter: + self.add(c) + def __eq__(self, other): l = [] for c in other: @@ -1243,6 +1252,16 @@ class FromClause(Selectable): self._primary_key = ColumnCollection() self._foreign_keys = util.Set() self._orig_cols = {} + for co in self._adjusted_exportable_columns(): + cp = self._proxy_column(co) + for ci in cp.orig_set: + self._orig_cols[ci] = cp + if self.oid_column is not None: + for ci in self.oid_column.orig_set: + self._orig_cols[ci] = self.oid_column + + def _adjusted_exportable_columns(self): + """return the list of ColumnElements represented within this FromClause's _exportable_columns""" export = self._exportable_columns() for column in export: try: @@ -1250,13 +1269,8 @@ class FromClause(Selectable): except AttributeError: continue for co in s.columns: - cp = self._proxy_column(co) - for ci in cp.orig_set: - self._orig_cols[ci] = cp - if self.oid_column is not None: - for ci in self.oid_column.orig_set: - self._orig_cols[ci] = self.oid_column - + yield co + def _exportable_columns(self): return [] @@ -1661,10 +1675,23 @@ class Join(FromClause): else: self.onclause = onclause self.isouter = isouter - + self.__folded_equivalents = None + self._init_primary_key() + name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name) encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) - + + def _init_primary_key(self): + pkcol = util.Set() + for col in self._adjusted_exportable_columns(): + if col.primary_key: + pkcol.add(col) + for col in list(pkcol): + for f in col.foreign_keys: + if f.column in pkcol: + pkcol.remove(col) + self.primary_key.extend(pkcol) + def _locate_oid_column(self): return self.left.oid_column @@ -1673,8 +1700,6 @@ class Join(FromClause): def _proxy_column(self, column): self._columns[column._label] = column - if column.primary_key: - self._primary_key.add(column) for f in column.foreign_keys: self._foreign_keys.add(f) return column @@ -1706,6 +1731,8 @@ class Join(FromClause): return True def _get_folded_equivalents(self, equivs=None): + if self.__folded_equivalents is not None: + return self.__folded_equivalents if equivs is None: equivs = util.Set() class LocateEquivs(NoColumnVisitor): @@ -1731,7 +1758,8 @@ class Join(FromClause): used.add(c.name) else: collist.append(c) - return collist + self.__folded_equivalents = collist + return self.__folded_equivalents def select(self, whereclause = None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. @@ -1740,9 +1768,11 @@ class Join(FromClause): the WHERE criterion that will be sent to the ``select()`` function fold_equivalents - based on the join criterion of this ``Join``, do not include equivalent - columns in the column list of the resulting select. this will recursively - apply to any joins directly nested by this one as well. + based on the join criterion of this ``Join``, do not include repeat + column names in the column list of the resulting select, for columns that + are calculated to be "equivalent" based on the join criterion of this + ``Join``. this will recursively apply to any joins directly nested by + this one as well. \**kwargs all other kwargs are sent to the underlying ``select()`` function diff --git a/test/orm/inheritance.py b/test/orm/inheritance.py index d09e685b39..0aabdb9be2 100644 --- a/test/orm/inheritance.py +++ b/test/orm/inheritance.py @@ -96,26 +96,41 @@ class InheritTest2(testbase.ORMTest): Column('foo_id', Integer, ForeignKey('foo.id')), Column('bar_id', Integer, ForeignKey('bar.bid'))) + def testget(self): + class Foo(object):pass + def __init__(self, data=None): + self.data = data + class Bar(Foo):pass + + mapper(Foo, foo) + mapper(Bar, bar, inherits=Foo) + + b = Bar('somedata') + sess = create_session() + sess.save(b) + sess.flush() + sess.clear() + + # test that "bar.bid" does not need to be referenced in a get + # (ticket 185) + assert sess.query(Bar).get(b.id).id == b.id + def testbasic(self): class Foo(object): def __init__(self, data=None): self.data = data - def __str__(self): - return "Foo(%s)" % self.data - def __repr__(self): - return str(self) mapper(Foo, foo) class Bar(Foo): - def __str__(self): - return "Bar(%s)" % self.data + pass mapper(Bar, bar, inherits=Foo, properties={ 'foos': relation(Foo, secondary=foo_bar, lazy=False) }) sess = create_session() - b = Bar('barfoo', _sa_session=sess) + b = Bar('barfoo') + sess.save(b) sess.flush() f1 = Foo('subfoo1') diff --git a/test/orm/mapper.py b/test/orm/mapper.py index f5a4613c95..942fdfdd9e 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -626,84 +626,6 @@ class MapperTest(MapperSuperTest): q3 = sess.query(User).options(eagerload('orders.items.keywords')) u = q3.select() self.assert_sql_count(db, go, 2) - -class InheritanceTest(MapperSuperTest): - - def testinherits(self): - class _Order(object): - pass - ordermapper = mapper(_Order, orders) - - class _User(object): - pass - usermapper = mapper(_User, users, properties = dict( - orders = relation(ordermapper, lazy = False) - )) - - class AddressUser(_User): - pass - mapper(AddressUser, addresses, inherits = usermapper) - - sess = create_session() - q = sess.query(AddressUser) - l = q.select() - - jack = l[0] - self.assert_(jack.user_name=='jack') - jack.email_address = 'jack@gmail.com' - sess.flush() - sess.clear() - au = q.get_by(user_name='jack') - self.assert_(au.email_address == 'jack@gmail.com') - - def testinherits2(self): - class _Order(object): - pass - class _Address(object): - pass - class AddressUser(_Address): - pass - ordermapper = mapper(_Order, orders) - addressmapper = mapper(_Address, addresses) - usermapper = mapper(AddressUser, users, inherits = addressmapper, - properties = { - 'orders' : relation(ordermapper, lazy=False) - }) - sess = create_session() - l = sess.query(usermapper).select() - jack = l[0] - self.assert_(jack.user_name=='jack') - jack.email_address = 'jack@gmail.com' - sess.flush() - sess.clear() - au = sess.query(usermapper).get_by(user_name='jack') - self.assert_(au.email_address == 'jack@gmail.com') - - def testlazyoption(self): - """test that a lazy options gets created against its correct mapper when - using options with inheriting mappers""" - class _Order(object): - pass - class _User(object): - pass - class AddressUser(_User): - pass - ordermapper = mapper(_Order, orders) - usermapper = mapper(_User, users, - properties = { - 'orders' : relation(ordermapper, lazy=True) - }) - amapper = mapper(AddressUser, addresses, inherits = usermapper) - - sess = create_session() - - def go(): - l = sess.query(AddressUser).options(lazyload('orders')).select() - # this would fail because the "orders" lazyloader gets created against AddressUsers selectable - # and not _User's. - assert len(l[0].orders) == 3 - self.assert_sql_count(db, go, 2) - class DeferredTest(MapperSuperTest):